Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
281ecd0b
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
281ecd0b
编写于
7月 07, 2023
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(xla): support IndexingMultiAxisVec and IndexingIncrMultiAxisVec
GitOrigin-RevId: ca13d142ef5f6d952350f7217f1aebc2ff644dd6
上级
5e013d8c
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
134 addition
and
3 deletion
+134
-3
imperative/python/megengine/xla/rules/indexing.py
imperative/python/megengine/xla/rules/indexing.py
+97
-3
imperative/python/test/unit/xla/functional/test_xla_indexing.py
...tive/python/test/unit/xla/functional/test_xla_indexing.py
+37
-0
未找到文件。
imperative/python/megengine/xla/rules/indexing.py
浏览文件 @
281ecd0b
...
@@ -409,11 +409,15 @@ def scatter(
...
@@ -409,11 +409,15 @@ def scatter(
oshape
,
odtype
=
oup_var
.
shape
,
oup_var
.
dtype
oshape
,
odtype
=
oup_var
.
shape
,
oup_var
.
dtype
else
:
else
:
oshape
,
odtype
=
x
.
shape
,
x
.
dtype
oshape
,
odtype
=
x
.
shape
,
x
.
dtype
indices
=
(
ir_utils
.
ir_constant
(
indices
)
if
not
isinstance
(
indices
,
HLOTensor
)
else
indices
.
tensor
)
op
=
hlo
.
ScatterOp
(
op
=
hlo
.
ScatterOp
(
ir_utils
.
make_ir_type_according_meta_tuple
(
oshape
,
odtype
),
ir_utils
.
make_ir_type_according_meta_tuple
(
oshape
,
odtype
),
[
x
.
tensor
],
[
x
.
tensor
],
i
r_utils
.
ir_constant
(
indices
)
,
i
ndices
,
[
y
.
tensor
],
[
y
.
tensor
],
scatter_dnums
,
scatter_dnums
,
indices_are_sorted
=
ir
.
BoolAttr
.
get
(
indices_are_sorted
),
indices_are_sorted
=
ir
.
BoolAttr
.
get
(
indices_are_sorted
),
...
@@ -424,7 +428,32 @@ def scatter(
...
@@ -424,7 +428,32 @@ def scatter(
update
=
op
.
update_computation
.
blocks
.
append
(
scalar_type
,
scalar_type
)
update
=
op
.
update_computation
.
blocks
.
append
(
scalar_type
,
scalar_type
)
with
ir
.
InsertionPoint
(
update
):
with
ir
.
InsertionPoint
(
update
):
if
mode
==
"add"
:
add
=
hlo
.
AddOp
(
*
update
.
arguments
)
hlo
.
ReturnOp
(
add
.
results
)
else
:
hlo
.
ReturnOp
((
update
.
arguments
[
1
],))
hlo
.
ReturnOp
((
update
.
arguments
[
1
],))
return
HLOTensor
(
op
.
results
)
def
gather
(
x
,
indices
,
dnums
,
slice_sizes
,
indices_are_sorted
=
False
,
unique_indices
=
False
,
):
gather_dnums
=
hlo
.
GatherDimensionNumbers
.
get
(
collapsed_slice_dims
=
list
(
dnums
.
collapsed_slice_dims
),
index_vector_dim
=
len
(
indices
.
shape
)
-
1
,
offset_dims
=
list
(
dnums
.
offset_dims
),
start_index_map
=
list
(
dnums
.
start_index_map
),
)
op
=
hlo
.
GatherOp
(
x
.
tensor
,
indices
.
tensor
,
gather_dnums
,
indices_are_sorted
=
ir
.
BoolAttr
.
get
(
indices_are_sorted
),
slice_sizes
=
ir_utils
.
dense_int_elements
(
slice_sizes
),
)
return
HLOTensor
(
op
.
results
)
return
HLOTensor
(
op
.
results
)
...
@@ -554,3 +583,68 @@ def indexing_set_one_hot_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]
...
@@ -554,3 +583,68 @@ def indexing_set_one_hot_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]
assert
ctx
.
op
.
ndim
==
args
[
0
].
ndim
,
f
"
{
ctx
.
op
.
ndim
}
,
{
args
[
0
].
shape
}
"
assert
ctx
.
op
.
ndim
==
args
[
0
].
ndim
,
f
"
{
ctx
.
op
.
ndim
}
,
{
args
[
0
].
shape
}
"
return
indexing_set_with_tensor_index
(
args
[
0
],
args
[
2
],
args
[
1
],
ctx
.
op
.
axis
)
return
indexing_set_with_tensor_index
(
args
[
0
],
args
[
2
],
args
[
1
],
ctx
.
op
.
axis
)
def
convert_negative_index
(
indices
:
HLOTensor
,
max_indices
:
int
):
max_i
=
HLOTensor
(
np
.
array
([
max_indices
],
dtype
=
"int32"
))
zero
=
HLOTensor
(
np
.
array
([
0
],
dtype
=
"int32"
))
zeros
=
zero
.
broadcast_to
(
indices
.
shape
)
max_i
=
max_i
.
broadcast_to
(
indices
.
shape
)
positive_indices
=
indices
+
max_i
mask
=
indices
<
zeros
return
HLOTensor
(
hlo
.
SelectOp
(
mask
.
tensor
,
positive_indices
.
tensor
,
indices
.
tensor
).
results
)
@
register_lower_rule
(
mops
.
IndexingMultiAxisVec
)
def
vec_indexing_lower
(
ctx
,
*
args
:
Union
[
HLOTensor
,
Sequence
[
HLOTensor
]]):
assert
len
(
ctx
.
param
[
"items"
])
==
1
axis
,
_
,
_
,
_
,
is_index
=
ctx
.
param
[
"items"
][
0
]
assert
is_index
inp
=
args
[
0
]
indices
=
args
[
1
]
indices
=
convert_negative_index
(
indices
,
inp
.
shape
[
axis
])
offset_dims
=
tuple
(
i
for
i
in
range
(
len
(
inp
.
shape
))
if
i
!=
axis
)
collapsed_slice_dims
=
(
axis
,)
start_index_map
=
(
axis
,)
indices
=
indices
.
reshape
(
indices
.
shape
+
(
1
,))
slices_size
=
tuple
(
(
inp
.
shape
[
i
]
if
i
!=
axis
else
1
for
i
in
range
(
len
(
inp
.
shape
)))
)
return
gather
(
inp
,
indices
,
GatherDimensionNumbers
(
offset_dims
,
collapsed_slice_dims
,
start_index_map
),
slices_size
,
)
@
register_lower_rule
(
mops
.
IndexingIncrMultiAxisVec
)
def
vec_indexing_incr_lower
(
ctx
,
*
args
:
Union
[
HLOTensor
,
Sequence
[
HLOTensor
]]):
assert
len
(
ctx
.
param
[
"items"
])
==
1
axis
,
_
,
_
,
_
,
is_index
=
ctx
.
param
[
"items"
][
0
]
assert
is_index
inp
=
args
[
0
]
indices
=
args
[
2
]
indices
=
convert_negative_index
(
indices
,
inp
.
shape
[
axis
])
indices
=
indices
.
reshape
(
indices
.
shape
+
(
1
,))
y
=
args
[
1
]
offset_dims
=
tuple
(
i
for
i
in
range
(
len
(
inp
.
shape
))
if
i
!=
axis
)
collapsed_slice_dims
=
(
axis
,)
start_index_map
=
(
axis
,)
dnums
=
ScatterDimensionNumbers
(
update_window_dims
=
offset_dims
,
inserted_window_dims
=
collapsed_slice_dims
,
scatter_dims_to_operand_dims
=
start_index_map
,
)
out
=
scatter
(
inp
,
indices
,
y
,
dnums
,
indices_are_sorted
=
False
,
unique_indices
=
False
,
mode
=
"add"
,
)
return
out
imperative/python/test/unit/xla/functional/test_xla_indexing.py
浏览文件 @
281ecd0b
...
@@ -149,3 +149,40 @@ def test_indexing_one_hot():
...
@@ -149,3 +149,40 @@ def test_indexing_one_hot():
tester
((
4
,
8
,
16
),
-
1
,
False
)
tester
((
4
,
8
,
16
),
-
1
,
False
)
tester
((
4
,
1
,
16
),
-
2
,
True
)
tester
((
4
,
1
,
16
),
-
2
,
True
)
tester
((
4
,
1
,
16
),
-
2
,
False
)
tester
((
4
,
1
,
16
),
-
2
,
False
)
@
pytest
.
mark
.
skipif
(
int
(
platform
.
python_version_tuple
()[
1
])
<
8
,
reason
=
"need py38"
)
@
pytest
.
mark
.
skipif
(
platform
.
system
()
!=
"Linux"
,
reason
=
"only support linux now"
)
@
pytest
.
mark
.
skipif
(
not
is_cuda_available
(),
reason
=
"only support cuda now"
)
def
test_index_multi_vec
():
def
tester
(
x_shape
,
index_type
,
dtype
):
dtype
=
dtype
or
np
.
float32
x
=
tensor
(
np
.
random
.
randn
(
*
x_shape
),
dtype
=
dtype
)
max_val
=
x
.
shape
[
0
]
ind
=
tensor
(
np
.
random
.
randint
(
-
max_val
+
1
,
max_val
,
24
).
astype
(
"int32"
))
gm
=
GradManager
()
rand_num
=
tensor
(
np
.
random
.
random
(
x
[
ind
].
shape
).
astype
(
dtype
))
@
jit
.
xla_trace
(
without_host
=
True
,
capture_as_const
=
True
)
def
func
(
inp
,
ind
):
gm
.
attach
([
inp
])
with
gm
:
x
=
inp
if
index_type
==
"set"
:
x
[
ind
]
=
tensor
(
rand_num
)
else
:
x
=
x
[
ind
]
gm
.
backward
((
x
*
x
).
sum
())
return
x
,
inp
.
grad
mge_rsts
=
func
(
x
,
ind
)
xla_rsts
=
func
(
x
,
ind
)
for
mge_rst
,
xla_rst
in
zip
(
mge_rsts
,
xla_rsts
):
np
.
testing
.
assert_allclose
(
mge_rst
.
numpy
(),
xla_rst
.
numpy
(),
atol
=
1e-5
)
tester
((
3
,
4
,
5
,
6
),
"get"
,
np
.
float32
)
tester
((
3
,
4
,
5
,
6
),
"get"
,
np
.
float16
)
# tester((2,2,2,2), "set", np.float32)
# tester((3,4,5,6), "set", np.float16)
# tester((3,4,5,6), "set", np.float16)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录