Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
e82fa4ec
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
e82fa4ec
编写于
1月 29, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(gopt): using new_inp for build_chain in DelayBroadcast pass
GitOrigin-RevId: efc63771976a35647508b095015cb35a1e7f0c21
上级
a09fc5f7
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
83 addition
and
7 deletion
+83
-7
imperative/python/megengine/functional/nn.py
imperative/python/megengine/functional/nn.py
+3
-3
imperative/python/megengine/tensor.py
imperative/python/megengine/tensor.py
+5
-0
imperative/python/test/integration/test_bn.py
imperative/python/test/integration/test_bn.py
+41
-2
src/gopt/impl/misc.cpp
src/gopt/impl/misc.cpp
+8
-2
src/gopt/test/misc.cpp
src/gopt/test/misc.cpp
+26
-0
未找到文件。
imperative/python/megengine/functional/nn.py
浏览文件 @
e82fa4ec
...
...
@@ -784,10 +784,10 @@ def sync_batch_norm(
if
is_distributed
():
# reduce all nodes' data to calculate mean and variance
reduce_size
=
broadcast_to
(
Tensor
(
reduce_size
,
dtype
=
_dtype
),
[
1
]
*
_ndim
)
stat
=
concat
(
[
reduce_size
.
astype
(
_dtype
),
channel_x1s
,
channel_x2s
],
axis
=
1
reduce_size
=
broadcast_to
(
Tensor
(
reduce_size
).
astype
(
dtype
=
_dtype
),
[
1
]
*
_ndim
)
stat
=
concat
([
reduce_size
,
channel_x1s
,
channel_x2s
],
axis
=
1
)
stat
=
all_reduce_sum
(
stat
,
group
)
reduce_size
=
stat
[:,
:
1
].
reshape
(
1
)
channel_x1s
=
stat
[:,
1
:
1
+
_channels
]
...
...
imperative/python/megengine/tensor.py
浏览文件 @
e82fa4ec
...
...
@@ -18,6 +18,7 @@ from .core._wrap import device as as_device
from
.core.ops.builtin
import
Copy
,
GetVarShape
from
.core.tensor.array_method
import
ArrayMethodMixin
from
.device
import
_valid_device
,
get_default_device
from
.logger
import
get_logger
from
.utils.deprecation
import
deprecated
...
...
@@ -41,6 +42,10 @@ class Tensor(_Tensor, ArrayMethodMixin):
cn
=
device
.
_cn
if
isinstance
(
data
,
_Tensor
):
if
dtype
is
not
None
:
get_logger
().
warning
(
"dtype does not work when creating a new Tensor with another Tensor"
)
obj
=
_Tensor
.
__new__
(
cls
,
data
)
else
:
if
isinstance
(
data
,
np
.
ndarray
):
...
...
imperative/python/test/integration/test_bn.py
浏览文件 @
e82fa4ec
...
...
@@ -17,7 +17,7 @@ import megengine.optimizer as optimizer
from
megengine
import
Parameter
,
tensor
from
megengine.distributed.helper
import
get_device_count_by_fork
from
megengine.jit
import
trace
from
megengine.module
import
BatchNorm2d
,
Module
,
SyncBatchNorm
from
megengine.module
import
BatchNorm2d
,
Conv2d
,
Module
,
Sequential
,
SyncBatchNorm
def
run_frozen_bn
(
BNModule
,
use_trace
=
False
,
use_symbolic
=
False
):
...
...
@@ -68,7 +68,7 @@ def test_frozen_bn():
run_frozen_bn
(
BatchNorm2d
,
True
,
True
)
@
pytest
.
mark
.
skipif
(
get_device_count_by_fork
(
"gpu"
)
<
2
,
reason
=
"need more gpu device"
)
@
pytest
.
mark
.
require_ngpu
(
2
)
@
pytest
.
mark
.
isolated_distributed
def
test_frozen_synced_bn
():
@
dist
.
launcher
(
n_gpus
=
2
)
...
...
@@ -151,6 +151,45 @@ def test_trace_bn_forward_twice():
np
.
testing
.
assert_equal
(
y
.
numpy
(),
0
)
def run_syncbn(trace_mode):
    """Train a small Conv2d/SyncBatchNorm stack for three steps.

    Args:
        trace_mode: ``None`` runs the training step untraced; any other
            value wraps the step with ``trace`` and is forwarded as its
            ``symbolic`` argument.
    """
    # NOTE(review): block nesting below was reconstructed from a listing
    # without indentation -- confirm against the repository copy.
    x = F.ones([2, 16, 4, 4], dtype="float32")

    # Two Conv2d(16, 16, 1) + SyncBatchNorm(16) pairs, built in order.
    layers = [
        Conv2d(16, 16, 1),
        SyncBatchNorm(16),
        Conv2d(16, 16, 1),
        SyncBatchNorm(16),
    ]
    net = Sequential(*layers)

    gm = ad.GradManager().attach(
        net.parameters(), callbacks=dist.make_allreduce_cb("MEAN")
    )
    opt = optimizer.SGD(net.parameters(), 1e-3)

    def train_func(x):
        # One forward/backward/update step; returns the scalar mean loss.
        with gm:
            loss = net(x).mean()
            gm.backward(loss)
            opt.step().clear_grad()
        return loss

    if trace_mode is None:
        step = train_func
    else:
        step = trace(train_func, symbolic=trace_mode)

    for _ in range(3):
        # Fetch the loss as a numpy value on every iteration.
        step(x).numpy()
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
@pytest.mark.parametrize("trace_mode", [None, True, False])
def test_trace_several_syncbn(trace_mode):
    """Run ``run_syncbn`` under a 2-GPU ``dist.launcher`` for each trace mode.

    Parametrized over ``trace_mode`` in ``[None, True, False]``; the value is
    passed straight through to ``run_syncbn``.
    """

    @dist.launcher(n_gpus=2)
    def worker():
        # Body executed by the launcher; closes over the parametrized mode.
        run_syncbn(trace_mode)

    worker()
# https://github.com/MegEngine/MegEngine/issues/145
def
test_frozen_bn_no_affine
():
nchannel
=
3
...
...
src/gopt/impl/misc.cpp
浏览文件 @
e82fa4ec
...
...
@@ -226,8 +226,14 @@ void DelayBroadcastPass::apply(OptState& opt) const {
if
(
!
prev
)
prev
=
rewriter
.
get_var
(
opr
->
input
(
inp_idx
));
if
(
!
opr
->
same_type
<
opr
::
Broadcast
>
())
{
VarNodeArray
new_inp
=
opr
->
input
();
new_inp
.
at
(
inp_idx
)
=
prev
;
VarNodeArray
new_inp
(
opr
->
input
().
size
());
for
(
size_t
i
=
0
;
i
<
opr
->
input
().
size
();
i
++
)
{
if
(
i
==
inp_idx
)
{
new_inp
[
i
]
=
prev
;
}
else
{
new_inp
[
i
]
=
rewriter
.
get_var
(
opr
->
input
(
i
));
}
}
opt
.
call_with_opr
(
opr
,
[
&
]
{
// create new opr with the original opr's properties
auto
new_opr
=
serialization
::
copy_opr_shallow
(
...
...
src/gopt/test/misc.cpp
浏览文件 @
e82fa4ec
...
...
@@ -177,6 +177,32 @@ TEST_PASS(DelayBroadcastPass, LongChain) {
ASSERT_EQ
(
bcast
(
bcast
(
relu
(
relu
(
x
)),
y
),
z
),
out
);
}
// Builds reduce(typecvt(broadcast(x))) + broadcast(3), applies
// DelayBroadcastPass, and expects the graph
// (reduce(broadcast(typecvt(x))) + 3).broadcast(shape).
TEST_PASS(DelayBroadcastPass, ElemwiseChain) {
    // Convert a var to Int32.
    auto to_int32 = [](SymbolVar var) {
        return opr::TypeCvt::make(var, dtype::Int32());
    };

    // SUM-mode Reduce with axis INT_MAX and default data type; the scalar 1
    // is handed to Reduce::make as its third argument (presumably the target
    // shape -- confirm against opr::Reduce).
    auto sum_reduce = [](SymbolVar var) {
        SymbolVar target = var.make_scalar(1);
        opr::Reduce::Param param{
                opr::Reduce::Mode::SUM, INT_MAX,
                opr::Reduce::Param::DataType::DEFAULT};
        return opr::Reduce::make(var, param, target);
    };

    TensorShape bcast_shape{2, 2};
    auto x = mkvar("x", {1, 1});
    auto three = x.make_scalar(3);

    // Input graph: both operands of the addition are broadcast up front.
    auto graph = sum_reduce(to_int32(x.broadcast(bcast_shape))) +
                 three.broadcast(bcast_shape);

    auto optimized = gopt::GraphOptimizer{}
                             .add_pass<gopt::DelayBroadcastPass>()
                             .apply({{graph}})
                             .endpoint_vars()[0];

    // Reference graph: broadcast moved after the type conversion, and a
    // single broadcast applied to the sum.
    auto expected =
            (sum_reduce(to_int32(x).broadcast(bcast_shape)) + three)
                    .broadcast(bcast_shape);

    ASSERT_EQ(optimized, expected);
}
TEST_PASS
(
ExpandVirtualGradPass
,
Simple
)
{
auto
x
=
mkvar
(
"x"
);
check
(
x
*
2
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录