Commit 9d535d7a
Authored on Jul 05, 2023 by Megvii Engine Team

feat(xla): improve lower rule

GitOrigin-RevId: 55d43fe0f3666ef233612505a925662161b50bec

Parent: d8917c22
Showing 4 changed files with 114 additions and 25 deletions:
- imperative/python/megengine/jit/xla_backend.py (+4, -4)
- imperative/python/megengine/xla/rules/nn.py (+39, -20)
- imperative/python/megengine/xla/rules/reduction.py (+3, -1)
- imperative/python/test/unit/xla/functional/test_xla_nn.py (+68, -0)
imperative/python/megengine/jit/xla_backend.py

@@ -91,13 +91,13 @@ class xla_trace(trace):
         set_use_xla_backend(self.orig_use_xla)
 
     def convert_params_to_xla(self):
         from ..device import coalesce_free_memory
         from ..utils.module_utils import get_expand_structure
         from ..tensor import Tensor
 
         backend = self.xla_exec.backend
         devices = backend.local_devices()
-        _, device_id, _ = CompNode(get_default_device()).physical_locator
+        default_cn = CompNode(get_default_device())
+        _, device_id, _ = default_cn.physical_locator
         device_index = (
             0 if len(devices) == 0 else [d.id for d in devices].index(device_id)
         )
@@ -114,7 +114,7 @@ class xla_trace(trace):
             if np_array.shape == ():
                 np_array = np_array[np.newaxis]
             xla_array = backend.buffer_from_pyval(np_array, device)
-            tensor._reset(Tensor(xla_array))
+            tensor._reset(Tensor(xla_array, device=default_cn))
 
         for attr, _ in self.attr_to_key.items():
             param = get_expand_structure(attr[0], attr[1])
@@ -232,7 +232,7 @@ class xla_trace(trace):
                 return_vals.append(outputs[self.outkey2idx[i]])
         keeped_features = []
         for i in self.keeped_activation:
-            keeped_features.append(outputs[self.outkey2idx[i]])
+            keeped_features.append(tensor(outputs[self.outkey2idx[i]], device=cn))
         out_tensors = []
         for array in return_vals:
             if array is not None:
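The three hunks above share one theme: the target CompNode is resolved once (default_cn / cn) and passed explicitly when arrays or XLA buffers are wrapped into tensors, instead of relying on implicit device placement. A small illustrative sketch of that idiom using only public MegEngine APIs (not code from the patch):

    import numpy as np
    import megengine as mge
    from megengine.device import get_default_device

    default_device = get_default_device()           # e.g. "xpux" or "gpu0"
    arr = np.zeros((2, 3), dtype=np.float32)
    param = mge.Tensor(arr, device=default_device)  # explicitly pinned to a device
    print(param.device)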
imperative/python/megengine/xla/rules/nn.py

@@ -49,15 +49,16 @@ def convolution_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
     if opr.sparse == mops.BatchConvBias.Sparse.DENSE:
         feature_group_count, batch_group_count = 1, 1
     else:
-        assert ic == oc, "dwconv only support ic == oc"
         assert len(weight.shape) == 5, "mge dpconv weight dim is 5"
-        feature_group_count, batch_group_count = ic, 1
+        feature_group_count, batch_group_count = weight.shape[0], 1
         if opr.format == mops.AdaptivePooling.Format.NCHW:
-            assert (
-                weight.shape[1] == 1 and weight.shape[2] == 1
-            ), f"weight shape error: {weight.shape}"
-            xla_weight_shape = [weight.shape[i] for i in [0, 2, 3, 4]]
+            xla_weight_shape = [
+                weight.shape[0] * weight.shape[1],
+                weight.shape[2],
+                weight.shape[3],
+                weight.shape[4],
+            ]
             weight = reshape(weight, xla_weight_shape)
 
     feature_group_count = ir_utils.i64_attr(feature_group_count)
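This branch previously required a strictly depthwise weight (ic == oc, with per-group sizes of 1); it now appears to accept general grouped weights in MegEngine's 5-D layout (groups, out_per_group, in_per_group, H, W), collapsing the first two axes into the 4-D layout XLA expects and taking feature_group_count from the group axis. A rough numpy sketch of just the shape handling, with made-up sizes:

    import numpy as np

    # hypothetical grouped-conv weight: (groups, out_per_group, in_per_group, H, W)
    g, ocpg, icpg, fh, fw = 4, 2, 3, 3, 3
    w5 = np.zeros((g, ocpg, icpg, fh, fw), dtype=np.float32)

    # mirror of the new branch: merge (groups, out_per_group) into one output axis
    xla_weight_shape = [w5.shape[0] * w5.shape[1], w5.shape[2], w5.shape[3], w5.shape[4]]
    w4 = w5.reshape(xla_weight_shape)        # (g*ocpg, icpg, fh, fw)
    feature_group_count = w5.shape[0]        # one XLA feature group per conv group
    print(w4.shape, feature_group_count)     # (8, 3, 3, 3) 4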
@@ -159,14 +160,16 @@ def _conv_general_vjp_rhs_padding(
     return list(zip(pads_lo, pads_hi))
 
 
-@register_lower_rule("ConvolutionBackwardDataV2")
+@register_lower_rule("ConvolutionBackwardDataV2", mops.ConvolutionBackwardData)
 def conv_backward_data_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
-    assert len(args) == 3 and len(ctx.vars_out) == 1 and len(ctx.vars_in) == 3
     assert (
         ctx.param["dilate_h"] == 1 and ctx.param["dilate_w"] == 1
     ), "dilate_conv is not support now"
-    weight, dout, inp = args[0], args[1], args[2]
+    if len(args) == 3:
+        weight, dout, inp = args[0], args[1], args[2]
+    else:
+        weight, dout, inp = args[0], args[1], None
     if ctx.param["format"] == mops.AdaptivePooling.Format.NCHW:
         dnums = ((0, 1, 2, 3), (0, 1, 2, 3), (0, 1, 2, 3))
         inp_spec, weight_spec, out_spec = dnums
@@ -177,8 +180,8 @@ def conv_backward_data_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
     ph, pw = ctx.param["pad_h"], ctx.param["pad_w"]
     padding = ((ph, ph), (pw, pw))
     weight_shape = weight.shape
-    inp_shape = inp.shape
-    ic = inp.shape[1]  # NCHW
+    inp_shape = inp.shape if inp else ctx.vars_out[0].shape
+    ic = inp_shape[1]  # NCHW
     oc = weight.shape[0]  # OIHW or O11HW for dwconv
     t_weight_spec = (weight_spec[1], weight_spec[0]) + weight_spec[2:]
     dnums = hlo.ConvDimensionNumbers.get(
@@ -196,11 +199,23 @@ def conv_backward_data_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
     if ctx.param["sparse"] == mops.BatchConvBias.Sparse.DENSE:
         feature_group_count, batch_group_count = 1, 1
     else:
-        assert ic == oc, "only support dpwise conv currently"
-        assert len(weight.shape) == 5, "mge dpconv weight dim is 5"
-        feature_group_count, batch_group_count = ic, 1
-        weight_shape = [weight.shape[i] for i in [2, 0, 3, 4]]
+        weight_shape = weight.shape
+        assert len(weight_shape) == 5, "mge dpconv weight dim is 5"
+        feature_group_count, batch_group_count = weight.shape[0], 1
+        weight_shape = [
+            weight.shape[1],
+            weight.shape[0] * weight.shape[2],
+            weight.shape[3],
+            weight.shape[4],
+        ]
+        weight = weight.transpose((1, 0, 2, 3, 4))
+        weight = weight.reshape(weight_shape)
+        weight_shape = [
+            weight_shape[1],
+            weight_shape[0],
+            weight_shape[2],
+            weight_shape[3],
+        ]
     padding = _conv_general_vjp_lhs_padding(
         np.take(inp_shape, inp_hw),
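For the data gradient, the grouped 5-D weight is likewise no longer restricted to the depthwise case: the group axis is transposed behind the per-group output axis and then merged with the per-group input channels. A numpy sketch of the shape bookkeeping only (sizes are hypothetical):

    import numpy as np

    # hypothetical grouped weight: (groups, out_per_group, in_per_group, H, W)
    g, ocpg, icpg, fh, fw = 4, 2, 3, 3, 3
    w5 = np.zeros((g, ocpg, icpg, fh, fw), dtype=np.float32)

    # mirror of the new code path: move the group axis behind out_per_group,
    # then collapse (groups, in_per_group) into a single channel axis
    w4 = w5.transpose((1, 0, 2, 3, 4)).reshape(
        (w5.shape[1], w5.shape[0] * w5.shape[2], fh, fw)
    )
    print(w4.shape)  # (2, 12, 3, 3), i.e. (out_per_group, groups*in_per_group, H, W)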
@@ -262,11 +277,15 @@ def conv_backward_filter_lower(ctx, *args: Union[HLOTensor, Sequence[HLOTensor]]):
     if ctx.param["sparse"] == mops.BatchConvBias.Sparse.DENSE:
         feature_group_count, batch_group_count = 1, 1
     else:
-        assert ic == oc, "only support dpwise conv currently"
-        assert len(weight.shape) == 5, "mge dpconv weight dim is 5"
-        feature_group_count, batch_group_count = ic, 1
-        weight_shape = [weight.shape[i] for i in [2, 0, 3, 4]]
+        weight_shape = weight.shape
+        assert len(weight_shape) == 5, "mge dpconv weight dim is 5"
+        feature_group_count, batch_group_count = weight.shape[0], 1
+        weight_shape = [
+            weight_shape[2],
+            weight_shape[0] * weight_shape[1],
+            weight_shape[3],
+            weight_shape[4],
+        ]
 
     if batch_group_count > 1:
         feature_group_count = batch_group_count
         batch_group_count = 1
imperative/python/megengine/xla/rules/reduction.py

@@ -138,7 +138,9 @@ def reduce_lower(ctx, *args: Union[ir.Value, Sequence[ir.Value]]):
     else:
         assert len(args) == 2
         src_shape = args[0].shape
-        tgt_shape = list(ctx.module_context.get_value(ctx.vars_in[1]))
+        if src_shape == ctx.vars_out[0].shape:
+            return args[0]
+        tgt_shape = list(ctx.vars_out[0].shape)
         tgt_shape = [1,] * (len(src_shape) - len(tgt_shape)) + tgt_shape
 
         src_idx, tgt_idx, axes = 0, 0, []
         while src_idx < len(src_shape) and tgt_idx < len(tgt_shape):
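The fallback branch now short-circuits when the source already has the output shape, and otherwise takes the target shape from ctx.vars_out[0] before left-padding it with ones to align it against the source shape. One plausible reading of that alignment, sketched in numpy with made-up shapes:

    import numpy as np

    src_shape, tgt_shape = (2, 3, 4, 5), [3, 1, 5]
    # pad the target with leading 1s so both shapes have the same rank
    tgt_shape = [1] * (len(src_shape) - len(tgt_shape)) + tgt_shape   # [1, 3, 1, 5]
    # dimensions where the target collapses to 1 are the ones to reduce over
    axes = tuple(i for i, (s, t) in enumerate(zip(src_shape, tgt_shape)) if s != t and t == 1)
    out = np.ones(src_shape).sum(axis=axes, keepdims=True)
    print(axes, out.shape)  # (0, 2) (1, 3, 1, 5)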
imperative/python/test/unit/xla/functional/test_xla_nn.py

@@ -93,6 +93,74 @@ def test_conv2d():
         padding=(2, 1),
         groups=16,
     )
     tester(
         (4, 16, 24, 24),
         (4, 4, 4, 1, 1),
         (1, 16, 1, 1),
         stride=(2, 3),
         padding=(2, 1),
         groups=4,
     )
 
 
+@pytest.mark.skipif(
+    int(platform.python_version_tuple()[1]) < 8, reason="need py38",
+)
+@pytest.mark.skipif(platform.system() != "Linux", reason="only support linux now")
+@pytest.mark.skipif(not is_cuda_available(), reason="only support cuda now")
+def test_conv_transpose2d():
+    np.random.seed(123)
+    mge.random.seed(123)
+
+    def tester(x_shape, w_shape, b_shape, stride, padding, groups, dtype=None):
+        dtype = dtype or np.float32
+        x = tensor(0.1 * np.random.rand(*x_shape), dtype=dtype)
+        w = tensor(0.1 * np.random.rand(*w_shape), dtype=dtype)
+        b = tensor(0.1 * np.random.rand(*b_shape), dtype=dtype) if b_shape else None
+        y = F.conv_transpose2d(x, w, b, stride=stride, padding=padding, groups=groups)
+        dy = tensor(0.1 * np.random.rand(*y.shape), dtype=dtype)
+
+        gm = GradManager()
+
+        if b is not None:
+
+            @jit.xla_trace(without_host=True)
+            def func(x, w, b, dy):
+                gm.attach([x, w, b])
+                with gm:
+                    y = F.conv_transpose2d(
+                        x, w, b, stride=stride, padding=padding, groups=groups
+                    )
+                    gm.backward(y, dy)
+                return [y, x.grad, w.grad, b.grad]
+
+            mge_rsts = func(x, w, b, dy)
+            xla_rsts = func(x, w, b, dy)
+        else:
+
+            @jit.xla_trace(without_host=True)
+            def func(x, w, dy):
+                gm.attach([x, w])
+                with gm:
+                    y = F.conv2d(x, w, stride=stride, padding=padding, groups=groups)
+                    gm.backward(y, dy)
+                return [y, x.grad, w.grad]
+
+            mge_rsts = func(x, w, dy)
+            xla_rsts = func(x, w, dy)
+
+        for mge_rst, xla_rst in zip(mge_rsts, xla_rsts):
+            np.testing.assert_allclose(mge_rst.numpy(), xla_rst.numpy(), atol=1e-4)
+
+    tester(
+        (4, 16, 24, 24), (16, 32, 3, 3), (1, 32, 1, 1), stride=1, padding=1, groups=1
+    )
+    tester(
+        (4, 16, 24, 24),
+        (16, 32, 3, 3),
+        (1, 32, 1, 1),
+        stride=(2, 3),
+        padding=(2, 1),
+        groups=1,
+    )
 
 
 @pytest.mark.skipif(
     int(platform.python_version_tuple()[1]) < 8, reason="need py38",
 )
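Per the skip markers above, the new test expects Linux, an available CUDA device, and Python >= 3.8; under those conditions it can be run on its own with something like:

    pytest imperative/python/test/unit/xla/functional/test_xla_nn.py -k conv_transpose2d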