MegEngine 天元 / MegEngine — Commit 2f06d580

Authored on Jul 24, 2023 by Megvii Engine Team

feat(xla): add topk and sort for xla

GitOrigin-RevId: 0e881f30429a8d849ad9cdd0e0f47c3e0921ff97
Parent: b0470e73
Showing 6 changed files with 352 additions and 15 deletions (+352, -15):

- imperative/python/megengine/xla/ir_utils.py (+5, -0)
- imperative/python/megengine/xla/rules/indexing.py (+2, -2)
- imperative/python/megengine/xla/rules/math.py (+193, -2)
- imperative/python/megengine/xla/rules/tensor.py (+39, -10)
- imperative/python/test/unit/xla/functional/test_xla_math.py (+106, -1)
- imperative/python/test/unit/xla/functional/test_xla_tensor.py (+7, -0)
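For context, a minimal usage sketch of what this commit enables, assuming the MegEngine APIs exercised by the new tests below (`F.sort`, `F.topk`, `jit.xla_trace(without_host=True)`; per the tests, this currently requires Linux, CUDA, and Python 3.8+):

```python
# Hedged sketch: sort/argsort/topk inside an XLA-traced MegEngine function.
import numpy as np
import megengine.functional as F
from megengine import jit, tensor

@jit.xla_trace(without_host=True)
def fn(x):
    values, indices = F.sort(x, descending=True)  # now lowered via hlo.SortOp
    topv, topi = F.topk(x, 3, descending=True)    # now lowered via chlo.TopKOp
    return values, indices, topv, topi

outs = fn(tensor(np.random.randn(4, 16), dtype=np.float32))
```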
imperative/python/megengine/xla/ir_utils.py

An `int8` entry is added to the dtype name map, and a `bool_attr` helper is introduced (used below for `hlo.SortOp`'s `is_stable` attribute):

```diff
@@ -192,6 +192,7 @@ class TraceResult:
     dtype_to_str = {
         "float16": "f16",
         "float32": "f32",
+        "int8": "i8",
         "int32": "i32",
         "int64": "i64",
         "uint8": "u8",
@@ -417,6 +418,10 @@ def f32_attr(i):
     return ir.FloatAttr.get(ir.F32Type.get(), i)


+def bool_attr(i):
+    return ir.BoolAttr.get(i)
+
+
 def precision_attr(lhs_prec, rhs_prec) -> ir.ArrayAttr:
     lhs_prec = str(lhs_prec)
     rhs_prec = str(rhs_prec)
```
imperative/python/megengine/xla/rules/indexing.py

Docstring punctuation fixes only:

```diff
@@ -66,7 +66,7 @@ def _hslice_with_step_is_one(inp, slices):
 def _hslice_with_any_step(inp, slices):
     """
-    if inp_shape is N-dim, slices should contain N slice, slice can not None
+    if inp_shape is N-dim, slices should contain N slice, slice can not None.
     for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)]
     """
     starts = [int(sl.start) for sl in slices]
@@ -83,7 +83,7 @@ def _hslice_with_any_step(inp, slices):
 def index_with_slices(inp, slices):
     """
-    if inp_shape is N-dim, slices should contain N slice, slice can be None
+    if inp_shape is N-dim, slices should contain N slice, slice can be None.
     for shape [12, 15], slices can be [slice(0, 3, 1), slice(12, 15, 1)] or [None, None]
     """
     assert isinstance(slices, Sequence), f"{slices}"
```
imperative/python/megengine/xla/rules/math.py

New imports for the sort/topk lowering rules:

```diff
@@ -4,9 +4,13 @@ import numpy as np
 from ...core._imperative_rt import ops as mops
 from .. import ir_utils
-from ..ir_utils import i64_attr
+from ..ir_utils import bool_attr, i64_attr
 from ..lib.mlir import ir
 from ..lib.mlir.dialects import chlo, hlo
+from ..utils import flatten_list
 from .hlotensor import HLOTensor
+from .indexing import ScatterDimensionNumbers, scatter
+from .tensor import concat, expand_dims, fill, iota
 from .utils import _can_broadcast_to, _shape_equal, register_lower_rule
```
The rest of the file's changes (hunk @@ -241,5 +245,192 @@, appended after `batched_matmul_lower`) are new functions. `_sort_according_to_key` builds an `hlo.SortOp` whose comparator region compares only the keys, so any number of value tensors is reordered in lockstep with the key:

```python
def _sort_according_to_key(key, *vals, axis=-1, descending=True, is_stable=True):
    """
    sort key and vals in the specified axis, return the sorted key and vals.
    key and vals should have the same shape; we reorder both key and vals according
    to the value of the key.
    example 1: (implement argsort)
        inp: 1.7783 -> 0, -1.8184 -> 1, 1.0701 -> 2
            [[ 1.7783 -1.8184  1.0701]
             [-0.0712 -1.4623  1.3243]]
            [[0 1 2]
             [0 1 2]]
        axis: -1
        descend: True
        return: after reorder, 1.7783 -> 0, -1.8184 -> 1, 1.0701 -> 2
            [[ 1.7783  1.0701 -1.8184]
             [ 1.3243 -0.0712 -1.4623]]
            [[0 2 1]
             [2 0 1]]
    example 2:
        inp:
            [[0 2 1]
             [2 0 1]]
            [[ 1.7783  1.0701 -1.8184]
             [ 1.3243 -0.0712 -1.4623]]
        axis: -1
        descend: False
        return:
            [[0 1 2]
             [0 1 2]]
            [[ 1.7783 -1.8184  1.0701]
             [-0.0712 -1.4623  1.3243]]
    """
    for val in vals:
        assert _shape_equal(
            key.shape, val.shape
        ), f"sort key and vals shape mismatch: {key.shape}, {val.shape}"
    axis = axis + key.ndim if axis < 0 else axis

    sorted_key = ir_utils.make_ir_type_according_meta(key.shape, key.dtype)
    sorted_vals = [
        ir_utils.make_ir_type_according_meta(val.shape, val.dtype) for val in vals
    ]
    sort_op = hlo.SortOp(
        [sorted_key, *sorted_vals],
        [key.tensor, *[val.tensor for val in vals]],
        dimension=i64_attr(axis),
        is_stable=bool_attr(is_stable),
    )

    key_type = ir_utils.make_ir_type_according_meta(tuple(), key.dtype)
    val_types = [
        ir_utils.make_ir_type_according_meta(tuple(), val.dtype) for val in vals
    ]
    arg_types = [key_type] + val_types
    comparator = sort_op.comparator.blocks.append(
        *flatten_list(zip(arg_types, arg_types))
    )
    with ir.InsertionPoint(comparator):
        lhs = HLOTensor(comparator.arguments[0])
        rhs = HLOTensor(comparator.arguments[1])
        if descending:
            hlo.ReturnOp([(lhs > rhs).tensor])
        else:
            hlo.ReturnOp([(lhs < rhs).tensor])

    assert len(sort_op.results) == len(vals) + 1, f"{len(vals)}, {len(sort_op.results)}"
    return (HLOTensor(ret) for ret in sort_op.results)
```
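A plain-NumPy sketch of the intended semantics (not the HLO lowering itself): sorting `vals` by `key` is equivalent to indexing both with the stable argsort of the key:

```python
# NumPy sketch of _sort_according_to_key semantics (last axis).
import numpy as np

def sort_according_to_key_np(key, *vals, descending=True):
    order = np.argsort(-key if descending else key, axis=-1, kind="stable")
    take = lambda a: np.take_along_axis(a, order, axis=-1)
    return (take(key), *[take(v) for v in vals])

key = np.array([[1.7783, -1.8184, 1.0701]])
idx = np.array([[0, 1, 2]])
sorted_key, sorted_idx = sort_according_to_key_np(key, idx)
# sorted_key -> [[ 1.7783  1.0701 -1.8184]], sorted_idx -> [[0 2 1]]
```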
`argsort` pairs the input with an iota of indices and sorts both by value:

```python
def argsort(inp, axis=-1, descending=True, is_stable=True):
    """
    sort inp in the specific axis, and return the sorted value and index.
    for example:
        inp:
            [[ 1.7783 -1.8184  1.0701]
             [-0.0712 -1.4623  1.3243]]
        axis: -1
        descend: True
        return:
            [[ 1.7783  1.0701 -1.8184]
             [ 1.3243 -0.0712 -1.4623]]
            [[0 2 1]
             [2 0 1]]
    """
    axis = axis + inp.ndim if axis < 0 else axis
    idx = iota(np.int32, inp.shape, axis)
    return _sort_according_to_key(
        inp, idx, axis=axis, descending=descending, is_stable=is_stable
    )


@register_lower_rule(mops.Argsort)
def argsort_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
    assert (
        len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 2
    ), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}"
    assert ctx.op.order in [
        mops.Argsort.Order.DESCENDING,
        mops.Argsort.Order.ASCENDING,
    ], f"{ctx.op.order}"
    descending = ctx.op.order == mops.Argsort.Order.DESCENDING
    axis = args[0].ndim - 1  # megengine only supports sort in the last dimension
    return argsort(args[0], axis, descending, is_stable=True)
```
`ArgsortBackward` serves both sort and topk: when `dy` matches the input shape it is the plain argsort backward (sort `dy` by `idx`, ascending); otherwise it is the topk backward, which scatters `dy` into a zero-filled tensor at the top-k positions:

```python
@register_lower_rule("ArgsortBackward")
def argsort_backward_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
    assert (
        len(args) == 3 and len(ctx.vars_in) == 3 and len(ctx.vars_out) == 1
    ), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}"
    dy, idx, x = args[0], args[1], args[2]
    if _shape_equal(x.shape, dy.shape):
        # for argsort backward
        _, dx = _sort_according_to_key(
            idx, dy, axis=-1, descending=False, is_stable=True
        )
    else:
        # for topk backward, only support axis=-1 and the dx is 2d tensor
        dx = fill(0, ctx.vars_out[0].shape, ctx.vars_out[0].dtype)
        expander = iota(np.int32, idx.shape, dimension=0)
        idx = expand_dims(idx, -1)
        expander = expand_dims(expander, -1)
        idx = concat([expander, idx], -1)

        dnums = ScatterDimensionNumbers(
            update_window_dims=(),
            inserted_window_dims=(0, 1),
            scatter_dims_to_operand_dims=(0, 1),
        )
        dx = scatter(dx, idx, dy, dnums, unique_indices=True)
    return dx
```
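The same scatter can be mirrored in NumPy for the 2-d case the comment restricts to; `rows` plays the role of the `expander` iota that pairs each top-k column index with its row:

```python
# NumPy sketch of the topk-backward scatter (2-d input, axis=-1).
import numpy as np

def topk_backward_np(dy, idx, inp_shape):
    dx = np.zeros(inp_shape, dtype=dy.dtype)
    rows = np.arange(idx.shape[0])[:, None]  # the "expander" iota along dim 0
    dx[rows, idx] = dy                       # scatter dy at (row, col) positions
    return dx

dy = np.array([[0.1, 0.2], [0.3, 0.4]])
idx = np.array([[2, 0], [1, 3]])
topk_backward_np(dy, idx, (2, 4))
# [[0.2 0.  0.1 0. ]
#  [0.  0.3 0.  0.4]]
```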
`topk` lowers directly to `chlo.TopKOp`, which returns the largest k elements; the ascending case negates the input and negates the resulting values back:

```python
def topk(inp, k, descending=True, kth_only=False, no_sort=False):
    """
    do topk in the last dimension of inp, for example:
        inp.shape = (2, 3, 4), k = 2, out_shape = (2, 3, 2)
    """
    assert k > 0, f"k of topk must be bigger than 0, got {k}"
    assert no_sort == False, "no_sort must be False now"
    assert kth_only == False, "kth_only is not supported now"
    if descending == True:
        out, idx = [
            HLOTensor(rst) for rst in chlo.TopKOp(inp.tensor, i64_attr(k)).results
        ]
    else:
        inp = -inp
        out, idx = [
            HLOTensor(rst) for rst in chlo.TopKOp(inp.tensor, i64_attr(k)).results
        ]
        out = -out
    return out, idx


@register_lower_rule(mops.TopK)
def topk_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
    assert (
        len(args) == 2 and len(ctx.vars_in) == 2
    ), f"{len(args)}, {len(ctx.vars_in)}, {len(ctx.vars_out)}"
    assert isinstance(
        ctx.vars_in[1].bound_data, np.ndarray
    ), f"{ctx.vars_in[1].bound_data}"
    k = int(ctx.vars_in[1].bound_data)
    descending = True if k < 0 else False
    k = -k if k < 0 else k

    if ctx.op.mode == mops.TopK.Mode.VALUE_IDX_SORTED:
        assert len(ctx.vars_out) == 2, f"{len(ctx.vars_out)}"
        kth_only, no_sort = False, False
    elif ctx.op.mode == mops.TopK.Mode.VALUE_IDX_NOSORT:
        assert len(ctx.vars_out) == 2, f"{len(ctx.vars_out)}"
        kth_only, no_sort = False, True
    else:
        assert (
            ctx.op.mode == mops.TopK.Mode.KTH_ONLY
        ), f"invalid mode for topk, {ctx.op.mode}"
        kth_only, no_sort = True, False
        assert len(ctx.vars_out) == 1, f"{len(ctx.vars_out)}"
    return topk(args[0], k, descending, kth_only, no_sort)
```
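A NumPy check of the negation identity the `descending=False` branch relies on: the top-k of the negated input, with its values negated back, equals the smallest k of the original (and the indices need no correction):

```python
# NumPy check of the negation trick used for ascending topk.
import numpy as np

def topk_largest(x, k):
    idx = np.argsort(-x, axis=-1)[..., :k]
    return np.take_along_axis(x, idx, axis=-1), idx

x = np.random.randn(4, 16)
vals, idx = topk_largest(-x, 3)  # top-k of the negated input
vals = -vals                     # negate values back -> smallest 3 of x, ascending
np.testing.assert_allclose(vals, np.sort(x, axis=-1)[:, :3])
```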
imperative/python/megengine/xla/rules/tensor.py

`expand_dims` is rewritten to accept negative axes and to build the target shape with a single insert instead of a loop:

```diff
@@ -79,14 +79,13 @@ def transpose(inp, permutation):
 def expand_dims(inp, axis):
     assert isinstance(axis, int), f"only int axis supported, get {axis}"
-    axis = (axis + inp.ndim) if axis < 0 else axis
-    assert axis >= 0 and axis <= inp.ndim, f"invalid axis {axis} for {inp.shape}"
+    assert (
+        axis >= -inp.ndim - 1 and axis <= inp.ndim
+    ), f"invalid axis {axis} for {inp.shape}"

-    dst_shape = []
-    for i in range(inp.ndim):
-        if i == axis:
-            dst_shape.append(1)
-        dst_shape.append(inp.shape[i])
+    dst_shape = list(inp.shape)
+    insert_pos = axis if axis >= 0 else (axis + inp.ndim + 1)
+    dst_shape.insert(insert_pos, 1)

     return inp.reshape(tuple(dst_shape))
```
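A quick runnable sketch of the new insert-position arithmetic, which mirrors `np.expand_dims` (valid axes for an n-dim tensor run from `-n - 1` to `n`):

```python
# Sketch of the insert-position arithmetic in the rewritten expand_dims.
shape = (16, 32)  # ndim = 2, so valid axes are -3..2
for axis in (-3, -2, -1, 0, 1, 2):
    dst = list(shape)
    pos = axis if axis >= 0 else axis + len(shape) + 1
    dst.insert(pos, 1)
    print(axis, tuple(dst))
# -3 -> (1, 16, 32); -1 -> (16, 32, 1); 2 -> (16, 32, 1), as with np.expand_dims
```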
`Dimshuffle` lowering previously assumed a pure permutation; it now also handles patterns that insert or drop size-1 axes. The same hunk fixes a bug in `concat`'s negative-axis handling:

```diff
@@ -94,14 +93,29 @@ def expand_dims(inp, axis):
 @register_lower_rule(mops.Dimshuffle)
 def dim_shuffle_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
     assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1
-    permutation = ctx.op.pattern
-    return transpose(args[0], permutation)
+    # mge dimshuffle can do transpose and broadcast simultaneously
+    # for example:
+    #   case1: (16, 32, 64) with pattern [0, 2, 1] -> (16, 64, 32)
+    #   case2: (16, 32, 64) with pattern [0, -1, 2, -1, 1] -> (16, 1, 64, 1, 32)
+    #   case3: (16, 1, 64, 1, 32) with pattern [0, 4, 2] -> (16, 32, 64)
+    pattern = ctx.op.pattern
+    inp = args[0]
+    if len(pattern) == inp.ndim:
+        permutation = pattern
+        return transpose(inp, permutation)
+    elif len(pattern) > inp.ndim:
+        permutation = [item for item in pattern if item != -1]
+        return transpose(inp, permutation).reshape(ctx.vars_out[0].shape)
+    else:
+        permutation = [i for i in range(inp.ndim) if i not in pattern] + list(pattern)
+        return transpose(inp, permutation).reshape(ctx.vars_out[0].shape)


 def concat(inps, axis):
     assert len(inps) > 0, f"concat inputs should not be empty"
     if axis < 0:
-        axis = axis + inps[0].ndim[0]
+        axis = axis + inps[0].ndim

     hlo_inps = [inp.tensor for inp in inps]
```
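The three branches can be sanity-checked against NumPy; this sketch mirrors the cases listed in the comment:

```python
# NumPy sketch of the three Dimshuffle lowering cases.
import numpy as np

x = np.empty((16, 32, 64))

# case 1: pattern length == ndim -> plain transpose
assert x.transpose([0, 2, 1]).shape == (16, 64, 32)

# case 2: pattern longer than ndim, -1 marks new size-1 axes
pattern = [0, -1, 2, -1, 1]
perm = [p for p in pattern if p != -1]  # [0, 2, 1]
assert x.transpose(perm).reshape(16, 1, 64, 1, 32).shape == (16, 1, 64, 1, 32)

# case 3: pattern shorter than ndim, dropped axes must be size 1
y = np.empty((16, 1, 64, 1, 32))
pattern = [0, 4, 2]
perm = [i for i in range(y.ndim) if i not in pattern] + pattern  # [1, 3, 0, 4, 2]
assert y.transpose(perm).reshape(16, 32, 64).shape == (16, 32, 64)
```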
A new `iota` helper is added after `fill` (the second docstring example uses `dimension=-2`, which normalizes to 0):

```diff
@@ -175,6 +189,21 @@ def fill(value, shape, dtype):
     return broadcast_to(HLOTensor(value, dtype=dtype), shape)


+def iota(dtype, shape, dimension):
+    """
+    does something like arange.
+    for example:
+        shape = (2, 3), dimension=1, output is [[0, 1, 2], [0, 1, 2]]
+        shape = (2, 3), dimension=-2, output is [[0, 0, 0], [1, 1, 1]]
+    """
+    dimension = dimension + len(shape) if dimension < 0 else dimension
+    ret = hlo.IotaOp(
+        ir_utils.make_ir_type_according_meta(shape, dtype), ir_utils.i64_attr(dimension)
+    ).results
+    assert len(ret) == 1, f"{len(ret)}"
+    return HLOTensor(ret[0])
+
+
 @register_lower_rule(mops.Fill)
 def fill_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
     assert len(args) == 1 and len(ctx.vars_in) == 1 and len(ctx.vars_out) == 1
```
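`iota` behaves like an `arange` broadcast along the chosen dimension; a NumPy sketch of the equivalence:

```python
# NumPy sketch of iota(dtype, shape, dimension).
import numpy as np

def iota_np(dtype, shape, dimension):
    dimension = dimension + len(shape) if dimension < 0 else dimension
    ar = np.arange(shape[dimension], dtype=dtype)
    index = [None] * len(shape)          # np.newaxis everywhere ...
    index[dimension] = slice(None)       # ... except the iota dimension
    return np.broadcast_to(ar[tuple(index)], shape)

iota_np(np.int32, (2, 3), 1)   # [[0 1 2], [0 1 2]]
iota_np(np.int32, (2, 3), -2)  # [[0 0 0], [1 1 1]]
```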
imperative/python/test/unit/xla/functional/test_xla_math.py

A stray host readback is dropped from `test_matmul`:

```diff
@@ -31,7 +31,6 @@ def test_matmul():
         return out, lhs.grad, rhs.grad

     mge_rsts = func(lhs, rhs, dout)
-    mge_rsts[0].numpy()
     xla_rsts = func(lhs, rhs, dout)

     for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
```
The rest of the hunk (@@ -79,3 +78,109 @@) appends three new tests after the final `tester((1, 2, 8, 7), (4, 2, 2, 9, 8), True, True)`, `tester((1, 8, 7), (4, 3, 2, 8, 9), True, False)`, and `tester((1, 8, 7), (4, 3, 1, 9, 8), True, True)` calls of `test_matmul`:

```python
@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_sort_and_argsort():
    def tester(ishape, descending, dtype=None):
        dtype = dtype or np.float32
        inp1 = tensor(np.random.randn(*ishape), dtype=dtype)
        inp2 = tensor(np.random.randn(*ishape), dtype=dtype)
        dout = tensor(np.random.randn(*ishape), dtype=dtype)

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(inp1, inp2, dout):
            gm.attach([inp1, inp2])
            with gm:
                out, idx1 = F.sort(inp1, descending)
                idx2 = F.argsort(inp2, -descending)
                gm.backward(out, dout)
            return out, idx1, idx2, inp1.grad

        mge_rsts = func(inp1, inp2, dout)
        xla_rsts = func(inp1, inp2, dout)
        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)

    for descending in [True, False]:
        tester((16, 32), descending)
        tester((16, 1), descending)
        tester((1, 16), descending)
        tester((1, 1), descending)
        tester((16,), descending)
        tester((1,), descending)


@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_topk():
    def tester(ishape, k, descending, kth_only, no_sort, dtype=None):
        dtype = dtype or np.float32
        inp = tensor(np.random.randn(*ishape), dtype=dtype)
        out, _ = F.topk(inp, k, descending, kth_only, no_sort)
        dout = tensor(0.1 * np.random.randn(*out.shape), dtype=dtype)

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(inp, dout):
            gm.attach([inp])
            with gm:
                out, index = F.topk(inp, k, descending, kth_only, no_sort)
                gm.backward(out, dout)
            return out, index, inp.grad

        mge_rsts = func(inp, dout)
        xla_rsts = func(inp, dout)
        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)

    for descending in [True, False]:
        tester((2, 16,), 1, descending, False, False)
        tester((2, 16,), 8, descending, False, False)
        tester((1, 16,), 1, descending, False, False)
        tester((1, 16,), 5, descending, False, False)
        tester((16,), 8, descending, False, False)
        tester((16,), 8, descending, False, False)
        tester((1,), 1, descending, False, False)
        tester((1,), 1, descending, False, False)


@pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
def test_topk_accuracy():
    def tester(batch, nr_class, topk, dtype=None):
        dtype = dtype or np.float32
        logits = tensor(np.random.uniform(0, 1, (batch, nr_class)), dtype=dtype)
        target = tensor(np.random.randint(0, nr_class, (batch,), np.int32))
        out = F.topk_accuracy(logits, target, topk)
        dout = tensor(0.1 * np.random.randn(*out.shape), dtype=dtype)

        gm = GradManager()

        @jit.xla_trace(without_host=True)
        def func(logits, target, dout):
            gm.attach([logits])
            with gm:
                out = F.topk_accuracy(logits, target, topk)
                gm.backward(out, dout)
            return [out]

        mge_rsts = func(logits, target, dout)
        xla_rsts = func(logits, target, dout)
        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-5)

    tester(32, 1000, 10)
    tester(32, 1, 1)
    tester(1, 1000, 10)
    tester(1, 1, 1)
```
imperative/python/test/unit/xla/functional/test_xla_tensor.py

`test_transpose` gains cases with "x" (insert-axis) entries in the pattern:

```diff
@@ -113,6 +113,13 @@ def test_transpose():
     tester((2, 3, 1), (0, 1, 2))
     tester((2, 3, 1, 4), (3, 1, 0, 2))
+    tester((1,), ("x", 0))
+    # tester((1,), (0, 'x'))  # bug for mge
+    tester((1, 2), ("x", 0, 1))
+    tester((1, 2), (0, "x", 1))
+    # tester((1, 2), (0, 1, 'x'))  # bug for mge
+    tester((16, 32, 64), (0, "x", 2, "x", 1))


 @pytest.mark.skipif(int(platform.python_version_tuple()[1]) < 8, reason="need py38")
 @pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
```