Commit 40412e26
Authored on May 29, 2020 by Megvii Engine Team

feat(mge/module): add sync bn

GitOrigin-RevId: ae71a540d1ee044a5879ad029479ed19bc99cfb8
Parent: 3c32ad6d
Showing 9 changed files with 469 additions and 59 deletions (+469 / -59).
Changed files:
    python_module/megengine/distributed/functional.py   +4    -46
    python_module/megengine/distributed/helper.py        +53   -0
    python_module/megengine/distributed/util.py          +10   -0
    python_module/megengine/functional/__init__.py       +1    -0
    python_module/megengine/functional/nn.py             +125  -1
    python_module/megengine/module/__init__.py           +1    -1
    python_module/megengine/module/batchnorm.py          +49   -2
    python_module/megengine/optimizer/optimizer.py       +10   -8
    python_module/test/unit/module/test_batchnorm.py     +216  -1
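At a glance, the commit adds synchronized batch normalization in two forms, a functional op and a module. A minimal import sketch (illustrative only, based on the __init__ changes in this diff):

    from megengine.functional import sync_batch_norm   # functional form (nn.py)
    from megengine.module import SyncBatchNorm          # module form (batchnorm.py)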
python_module/megengine/distributed/functional.py (+4 / -46)

@@ -12,56 +12,14 @@ import megengine._internal as mgb
 from megengine._internal.opr_param_defs import CollectiveComm as CollParam

 from ..core import Buffer, Parameter, Tensor, wrap_io_tensor
 from ..core.graph import get_default_graph
 from ..functional import add_update
-from .util import (
-    get_backend,
-    get_master_ip,
-    get_master_port,
-    get_rank,
-    get_world_size,
-    is_distributed,
-)
+from .helper import collective_comm_symvar
+from .util import get_rank, is_distributed


 @wrap_io_tensor
-def _collective_comm(
-    inp: Union[Tensor, mgb.CompGraph],
-    key: str,
-    op: CollParam.Mode,
-    nr_ranks: Optional[int] = None,
-    rank: Optional[int] = None,
-    root: Optional[int] = 0,
-    dtype: Optional[type] = None,
-    device: Optional[mgb.CompNode] = None,
-    comp_graph: Optional[mgb.CompGraph] = None,
-) -> Tensor:
-    """Helper function for creating collective_comm operators
-
-    :param inp: tensor or comp_graph
-    :param key: unique identifier for collective communication
-    :param op: mode of collective communication
-    :param nr_ranks: number of ranks, use util.get_world_size() as default
-    :param rank: rank of the current process, use util.get_rank() as default
-    :param root: rank of root node, use 0 as default
-    :param dtype: output data type, use dtype of inp as default
-    :param device: output comp node, use comp node of inp as default
-    :param comp_graph: output comp graph, use comp graph of inp as default
-    """
-    return mgb.opr.collective_comm(
-        inp,
-        key=str(key),
-        nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
-        rank=rank if rank is not None else get_rank(),
-        root=root,
-        server_addr=get_master_ip(),
-        port=get_master_port(),
-        param=CollParam(mode=op),
-        dtype=dtype,
-        backend=get_backend(),
-        comp_node=device,
-        comp_graph=comp_graph,
-    )
+def _collective_comm(*args, **kargs):
+    return collective_comm_symvar(*args, **kargs)


 def reduce_sum(
python_module/megengine/distributed/helper.py (new file, +53)

# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Union

import megengine._internal as mgb
from megengine._internal.opr_param_defs import CollectiveComm as CollParam

from .util import get_backend, get_master_ip, get_master_port, get_rank, get_world_size


def collective_comm_symvar(
    inp: Union[mgb.SymbolVar, mgb.CompGraph],
    key: str,
    op: CollParam.Mode,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    root: Optional[int] = 0,
    dtype: Optional[type] = None,
    device: Optional[mgb.CompNode] = None,
    comp_graph: Optional[mgb.CompGraph] = None,
) -> mgb.SymbolVar:
    """Helper function for creating collective_comm operators

    :param inp: tensor or comp_graph
    :param key: unique identifier for collective communication
    :param op: mode of collective communication
    :param nr_ranks: number of ranks, use util.get_world_size() as default
    :param rank: rank of the current process, use util.get_rank() as default
    :param root: rank of root node, use 0 as default
    :param dtype: output data type, use dtype of inp as default
    :param device: output comp node, use comp node of inp as default
    :param comp_graph: output comp graph, use comp graph of inp as default
    """
    return mgb.opr.collective_comm(
        inp,
        key=str(key),
        nr_devices=nr_ranks if nr_ranks is not None else get_world_size(),
        rank=rank if rank is not None else get_rank(),
        root=root,
        server_addr=get_master_ip(),
        port=get_master_port(),
        param=CollParam(mode=op),
        dtype=dtype,
        backend=get_backend(),
        comp_node=device,
        comp_graph=comp_graph,
    )
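As a usage note (illustrative, condensed from the nn.py change later in this diff): the new functional sync_batch_norm reaches this helper through a small local wrapper that all-reduces its packed per-channel statistics across ranks.

    from megengine._internal.opr_param_defs import CollectiveComm as CollParam
    from megengine.distributed.helper import collective_comm_symvar

    def _allreduce(stat, key):
        # Sum `stat` across all ranks; `key` must be unique per collective operator.
        return collective_comm_symvar(stat, key, CollParam.Mode.ALL_REDUCE_SUM)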
python_module/megengine/distributed/util.py (+10 / -0)

@@ -19,6 +19,7 @@ _master_port = 0
 _world_size = 0
 _rank = 0
 _backend = None
+_group_id = 0


 def init_process_group(
@@ -43,6 +44,7 @@ def init_process_group(
     global _world_size  # pylint: disable=global-statement
     global _rank  # pylint: disable=global-statement
     global _backend  # pylint: disable=global-statement
+    global _group_id  # pylint: disable=global-statement

     if not isinstance(master_ip, str):
         raise TypeError("Expect type str but got {}".format(type(master_ip)))
@@ -60,6 +62,7 @@ def init_process_group(
     _world_size = world_size
     _rank = rank
     _backend = backend
+    _group_id = 0

     set_default_device(mgb.comp_node("gpu" + str(dev)))
@@ -101,6 +104,13 @@ def get_backend() -> str:
     return str(_backend)


+def get_group_id() -> int:
+    """Get group id for collective communication"""
+    global _group_id
+    _group_id += 1
+    return _group_id
+
+
 def group_barrier() -> None:
     """Block until all ranks in the group reach this barrier"""
     mgb.config.group_barrier(_master_ip, _master_port, _world_size, _rank)
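The new get_group_id() is a process-local counter used to derive a unique key for each collective operator, replacing the hand-maintained integer keys the optimizer used before. A minimal sketch of the keying pattern this commit adopts (illustrative, variable names are mine):

    from megengine.distributed.util import get_group_id

    # Each call returns the next id, so every collective gets a distinct key,
    # provided all ranks create their collectives in the same order.
    grad_key = "grad_" + str(get_group_id())          # Optimizer gradient all-reduce
    bcast_key = "bcast_param_" + str(get_group_id())  # Optimizer.bcast_param
    bn_key = "sync_bn_" + str(get_group_id())         # sync_batch_norm statistics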
python_module/megengine/functional/__init__.py (+1 / -0)

@@ -76,6 +76,7 @@ from .nn import (
     roi_pooling,
     softmax,
     softplus,
+    sync_batch_norm,
     warp_perspective,
 )
 from .quantized import conv_bias_activation
python_module/megengine/functional/nn.py (+125 / -1)

@@ -11,15 +11,20 @@ from typing import Optional, Tuple, Union
 import megengine._internal as mgb
 from megengine._internal import CompGraph, CompNode
+from megengine._internal.config import add_extra_vardep
+from megengine._internal.opr import add_update
+from megengine._internal.opr_param_defs import CollectiveComm as CollParam
+from .. import distributed as dist
 from ..core import Tensor, wrap_io_tensor
 from ..core.graph import _use_default_if_none
+from ..distributed.util import get_group_id
 from ..jit import barrier, mark_impure
 from ..random import uniform
 from ..utils.types import _pair, _pair_nonzero
 from .debug_param import get_conv_execution_strategy
 from .elemwise import exp, log
-from .tensor import where
+from .tensor import concat, where
 from .utils import _decide_comp_node_and_comp_graph
@@ -474,6 +479,125 @@ def batch_norm2d(
     return output


+@wrap_io_tensor
+def sync_batch_norm(
+    input: Tensor,
+    running_mean: Tensor,
+    running_var: Tensor,
+    weight: Optional[Tensor] = None,
+    bias: Optional[Tensor] = None,
+    training: bool = False,
+    momentum: Union[float, Tensor] = 0.9,
+    eps: float = 1e-5,
+    eps_mode="ADDITIVE",
+) -> Tensor:
+    """Applies synchronized batch normalization to the input.
+
+    Refer to :class:`~.BatchNorm2d` and :class:`~.BatchNorm1d` for more information.
+
+    :param inp: input tensor.
+    :param running_mean: tensor to store running mean.
+    :param running_var: tensor to store running variance.
+    :param weight: scaling tensor in the learnable affine parameters.
+        See :math:`\gamma` in :class:`~.BatchNorm2d`
+    :param bias: bias tensor in the learnable affine parameters.
+        See :math:`\beta` in :class:`~.BatchNorm2d`
+    :param training: a boolean value to indicate whether batch norm is performed
+        in training mode. Default: ``False``
+    :param momentum: the value used for the ``running_mean`` and ``running_var``
+        computation. Default: 0.9
+    :param eps: a value added to the denominator for numerical stability.
+        Default: 1e-5.
+    """
+    assert eps_mode in {"MAX", "ADDITIVE"}, "unknown eps_mode: {}".format(eps_mode)
+    input = mgb.opr.mark_no_broadcast_elemwise(input)
+    _channels = input.imm_shape[1]
+    _ndim = len(input.imm_shape)
+    _param_shape = (1, _channels) + (1,) * (_ndim - 2)
+
+    if training:
+
+        def _sum_on_channel(input):
+            return mgb.opr.reduce_general([input, _param_shape], mode="sum")
+
+        def _allreduce(stat, key):
+            return dist.helper.collective_comm_symvar(
+                stat, key, CollParam.Mode.ALL_REDUCE_SUM
+            )
+
+        reduce_size = input.shape[0]
+        for i in range(2, _ndim):
+            reduce_size = reduce_size * input.shape[i]
+        channel_x1s = _sum_on_channel(input)
+        channel_x2s = _sum_on_channel(input ** 2)
+
+        if dist.is_distributed():
+            # reduce all nodes' data to calculate mean and variance
+            reduce_size = reduce_size.reshape(*(1,) * _ndim)
+            stat = mgb.opr.concat([reduce_size, channel_x1s, channel_x2s], axis=1)
+            stat = _allreduce(stat, key="sync_bn_" + str(get_group_id()))
+            reduce_size = stat[:, :1].reshape(1)
+            channel_x1s = stat[:, 1 : 1 + _channels]
+            channel_x2s = stat[:, 1 + _channels :]
+
+        channel_mean = channel_x1s / reduce_size
+        channel_variance = (
+            channel_x1s ** 2 / (-reduce_size * reduce_size) + channel_x2s / reduce_size
+        )
+    else:
+        assert running_var is not None and running_mean is not None
+        channel_variance = running_var.reshape(*_param_shape)
+        channel_mean = running_mean.reshape(*_param_shape)
+
+    invsqrt_channel_variance = (
+        mgb.opr.elem.max(channel_variance, eps)
+        if eps_mode == "MAX"
+        else mgb.opr.elem.add(channel_variance, eps)
+    ) ** -0.5
+
+    if weight is not None:
+        weight = weight.reshape(*_param_shape)
+    if bias is not None:
+        bias = bias.reshape(*_param_shape)
+
+    # outvar = output * weight + bias
+    # where output = input * invsqrt_channel_variance + (
+    #    -channel_mean * invsqrt_channel_variance
+    # )
+    # Manually expand output for gopt
+    if weight is not None:
+        inv_var_wt = invsqrt_channel_variance * weight
+        neg_channel_mean = -channel_mean
+        if bias is not None:
+            outvar = input * inv_var_wt + (neg_channel_mean * inv_var_wt + bias)
+        else:
+            outvar = input * inv_var_wt + neg_channel_mean * inv_var_wt
+    else:
+        outvar = input * invsqrt_channel_variance + (
+            -channel_mean * invsqrt_channel_variance
+        )
+        if bias is not None:
+            outvar = outvar + bias
+
+    if training and running_var is not None and running_mean is not None:
+        _mean_update = add_update(
+            running_mean, channel_mean, alpha=momentum, beta=1 - momentum,
+        )
+        channel_variance_unbiased = channel_x1s ** 2 / (
+            -reduce_size * (reduce_size - 1)
+        ) + channel_x2s / (reduce_size - 1)
+        _variance_update = add_update(
+            running_var, channel_variance_unbiased, alpha=momentum, beta=1 - momentum
+        )
+        for dep in (_mean_update, _variance_update):
+            add_extra_vardep(outvar, dep)
+
+    return outvar


 def one_hot(inp: Tensor, num_classes: int) -> Tensor:
     r"""
     Perform one-hot encoding for the input tensor.
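The training branch above only aggregates three per-channel quantities across ranks (element count, sum, and sum of squares) and reconstructs mean and variance from them. A small NumPy sketch of that identity (illustrative only; variable names mirror the diff but the snippet is mine, not part of the commit):

    import numpy as np

    x = np.random.normal(size=(3, 8, 4, 16)).astype(np.float32)
    reduce_size = x.shape[0] * x.shape[2] * x.shape[3]   # elements per channel
    channel_x1s = x.sum(axis=(0, 2, 3))                  # per-channel sum
    channel_x2s = (x ** 2).sum(axis=(0, 2, 3))           # per-channel sum of squares

    mean = channel_x1s / reduce_size
    # Same algebra as in sync_batch_norm: E[x^2] - E[x]^2, written as
    # x1s**2 / (-N * N) + x2s / N
    var = channel_x1s ** 2 / (-reduce_size * reduce_size) + channel_x2s / reduce_size

    assert np.allclose(mean, x.mean(axis=(0, 2, 3)), atol=1e-4)
    assert np.allclose(var, x.var(axis=(0, 2, 3)), atol=1e-4)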
python_module/megengine/module/__init__.py (+1 / -1)

@@ -7,7 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
-from .batchnorm import BatchNorm1d, BatchNorm2d
+from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 from .concat import Concat
 from .conv import Conv2d, ConvTranspose2d, LocalConv2d
 from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
python_module/megengine/module/batchnorm.py (+49 / -2)

@@ -9,7 +9,7 @@
 import numpy as np

 from ..core import Buffer, Parameter
-from ..functional import batch_norm2d
+from ..functional import batch_norm2d, sync_batch_norm
 from . import init
 from .module import Module
@@ -74,7 +74,6 @@ class _BatchNorm(Module):
             inp = inp.reshape(new_shape)

-        _iter_update = None
         if self.training and self.track_running_stats:
             exponential_average_factor = self.momentum
         else:
@@ -97,6 +96,54 @@ class _BatchNorm(Module):
         return output


+class SyncBatchNorm(_BatchNorm):
+    r"""
+    Applies Synchronization Batch Normalization.
+    """
+
+    def _check_input_ndim(self, inp):
+        if len(inp.shape) not in {2, 3, 4}:
+            raise ValueError(
+                "expected 2D, 3D or 4D input (got {}D input)".format(len(inp.shape))
+            )
+
+    def forward(self, inp):
+        self._check_input_ndim(inp)
+
+        _ndims = len(inp.shape)
+        if _ndims != 4:
+            origin_shape = inp.shapeof()
+            if _ndims == 2:
+                n, c = inp.shapeof(0), inp.shapeof(1)
+                new_shape = (n, c, 1, 1)
+            elif _ndims == 3:
+                n, c, h = inp.shapeof(0), inp.shapeof(1), inp.shapeof(2)
+                new_shape = (n, c, h, 1)
+
+            inp = inp.reshape(new_shape)
+
+        if self.training and self.track_running_stats:
+            exponential_average_factor = self.momentum
+        else:
+            exponential_average_factor = 0.0  # useless
+
+        output = sync_batch_norm(
+            inp,
+            self.running_mean,
+            self.running_var,
+            self.weight,
+            self.bias,
+            self.training or not self.track_running_stats,
+            exponential_average_factor,
+            self.eps,
+        )
+
+        if _ndims != 4:
+            output = output.reshape(origin_shape)
+
+        return output
+
+
 class BatchNorm1d(_BatchNorm):
     r"""
     Applies Batch Normalization over a 2D/3D tensor.
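A condensed usage sketch of the new module, adapted from test_syncbn later in this diff (the server address, port, and shapes are the test's placeholder values; outside an initialized process group the module simply falls back to local statistics):

    import numpy as np
    import megengine.distributed as dist
    from megengine.core import tensor
    from megengine.module import SyncBatchNorm

    rank = 0  # this process's rank in a 4-process group
    dist.init_process_group("localhost", 2333, 4, rank, rank)

    bn = SyncBatchNorm(8, momentum=0.9, eps=1e-5)   # 8 channels
    data = tensor()
    data.set_value(np.random.normal(size=(3, 8, 4, 4)).astype(np.float32))
    out = bn(data)  # per-channel statistics are all-reduced across the 4 ranks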
python_module/megengine/optimizer/optimizer.py (+10 / -8)

@@ -18,6 +18,7 @@ from .._internal.config import opr_priority_scope
 from ..core import Buffer, Parameter, Tensor, TensorDict
 from ..core.graph import get_default_graph
 from ..distributed import all_reduce_sum, bcast_param, get_world_size, is_distributed
+from ..distributed.util import get_group_id
 from ..functional import add_update
 from ..functional import grad as grad_func
 from ..jit import sideeffect
@@ -152,7 +153,7 @@ class Optimizer(metaclass=ABCMeta):
         :param loss: The obtained loss tensor
         """
         rst = []
-        key = 0
+        priority = 0
         params = []
         for group in self.param_groups:
             for param in group["params"]:
@@ -173,11 +174,14 @@ class Optimizer(metaclass=ABCMeta):
         for param, grad in zip(params, grads):
             if is_distributed():
-                key += 1
-                with opr_priority_scope(cg, -key):
+                priority += 1
+                with opr_priority_scope(cg, -priority):
                     # all_reduce_mean
-                    grad = all_reduce_sum(grad, key) / get_world_size()
-                with opr_priority_scope(cg, (1 << 30) - key):
+                    grad = (
+                        all_reduce_sum(grad, "grad_" + str(get_group_id()))
+                        / get_world_size()
+                    )
+                with opr_priority_scope(cg, (1 << 30) - priority):
                     grad_update = add_update(param.grad, grad)
             else:
                 grad_update = add_update(param.grad, grad)
@@ -216,11 +220,9 @@ class Optimizer(metaclass=ABCMeta):
             param.grad.reset_zero()

     def bcast_param(self):
-        key = 0
         for group in self.param_groups:
             for param in group["params"]:
-                bcast_param(param, key)
-                key += 1
+                bcast_param(param, "bcast_param_" + str(get_group_id()))

     def state_dict(self) -> Dict:
         r"""Export the optimizer state.
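The gradient synchronization itself is unchanged in spirit: each parameter's gradient is summed across ranks and divided by the world size ("all-reduce-mean"); only the operator keys and scheduling priorities are derived differently. A tiny NumPy illustration of that semantics (purely illustrative, not MegEngine API):

    import numpy as np

    world_size = 4
    # One gradient tensor per rank for the same parameter.
    per_rank_grads = [
        np.random.normal(size=(8,)).astype(np.float32) for _ in range(world_size)
    ]

    # all_reduce_sum followed by division by world_size ...
    all_reduced = sum(per_rank_grads) / world_size

    # ... is just the element-wise mean of the ranks' gradients.
    assert np.allclose(all_reduced, np.mean(per_rank_grads, axis=0))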
python_module/test/unit/module/test_batchnorm.py (+216 / -1)

@@ -6,15 +6,86 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import multiprocessing as mp
+
 import numpy as np
+import pytest

 import megengine as mge
+import megengine.distributed as dist
 from megengine.core import tensor
-from megengine.module import BatchNorm1d, BatchNorm2d
+from megengine.module import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 from megengine.test import assertTensorClose


+@pytest.mark.isolated_distributed
+def test_syncbn():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4, 16)
+    momentum = 0.9
+    eps = 1e-5
+    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
+    steps = 4
+
+    def worker(rank, data, yv_expect, running_mean, running_var):
+        if not mge.is_cuda_available():
+            return
+        dist.init_process_group("localhost", 2333, 4, rank, rank)
+        bn = SyncBatchNorm(nr_chan, momentum=momentum, eps=eps)
+        data_tensor = tensor()
+        for i in range(steps):
+            data_tensor.set_value(data[i])
+            yv = bn(data_tensor)
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
+        assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)
+
+    xv = []
+    for i in range(steps):
+        xv.append(np.random.normal(loc=2.3, size=data_shape).astype(np.float32))
+        xv_transposed = np.transpose(xv[i], [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var_biased + eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        yv_expect = (xv[i] - mean) / sd
+
+    data = []
+    for i in range(4):
+        data.append([])
+        for j in range(steps):
+            data[i].append(xv[j][:, :, :, i * 4 : i * 4 + 4])
+
+    procs = []
+    for rank in range(4):
+        p = mp.Process(
+            target=worker,
+            args=(
+                rank,
+                data[rank],
+                yv_expect[:, :, :, rank * 4 : rank * 4 + 4],
+                running_mean,
+                running_var,
+            ),
+        )
+        p.start()
+        procs.append(p)
+
+    for p in procs:
+        p.join()
+        assert p.exitcode == 0
+
+
 def test_batchnorm():
     nr_chan = 8
     data_shape = (3, nr_chan, 4)
@@ -64,6 +135,55 @@ def test_batchnorm():
     assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)


+def test_syncbn1d():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4)
+    momentum = 0.9
+    bn = SyncBatchNorm(nr_chan, momentum=momentum)
+    running_mean = np.zeros((1, nr_chan, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1), dtype=np.float32)
+    data = tensor()
+    for i in range(3):
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
+        xv_transposed = np.transpose(xv, [0, 2, 1]).reshape(
+            (data_shape[0] * data_shape[2], nr_chan)
+        )
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1))
+        sd = np.sqrt(var_biased + bn.eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(
+            running_mean.reshape(-1), bn.running_mean.numpy().reshape(-1), max_err=5e-6
+        )
+        assertTensorClose(
+            running_var.reshape(-1), bn.running_var.numpy().reshape(-1), max_err=5e-6
+        )
+
+    # test set 'training' flag to False
+    mean_backup = bn.running_mean.numpy()
+    var_backup = bn.running_var.numpy()
+    bn.training = False
+    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+    data.set_value(xv)
+    yv1 = bn(data)
+    yv2 = bn(data)
+    assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
+    assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
+    assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
+    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
+    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
+
+
 def test_batchnorm2d():
     nr_chan = 8
     data_shape = (3, nr_chan, 16, 16)
@@ -110,6 +230,52 @@ def test_batchnorm2d():
     assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)


+def test_syncbn2d():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    momentum = 0.9
+    bn = SyncBatchNorm(nr_chan, momentum=momentum)
+    running_mean = np.zeros((1, nr_chan, 1, 1), dtype=np.float32)
+    running_var = np.ones((1, nr_chan, 1, 1), dtype=np.float32)
+    data = tensor()
+    for i in range(3):
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+
+        var_biased = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var_biased + bn.eps)
+
+        var_unbiased = np.var(xv_transposed, axis=0, ddof=1).reshape((1, nr_chan, 1, 1))
+        running_mean = running_mean * momentum + mean * (1 - momentum)
+        running_var = running_var * momentum + var_unbiased * (1 - momentum)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+        assertTensorClose(running_mean, bn.running_mean.numpy(), max_err=5e-6)
+        assertTensorClose(running_var, bn.running_var.numpy(), max_err=5e-6)
+
+    # test set 'training' flag to False
+    mean_backup = bn.running_mean.numpy()
+    var_backup = bn.running_var.numpy()
+    bn.training = False
+    xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+    data.set_value(xv)
+    yv1 = bn(data)
+    yv2 = bn(data)
+    assertTensorClose(yv1.numpy(), yv2.numpy(), max_err=0)
+    assertTensorClose(mean_backup, bn.running_mean.numpy(), max_err=0)
+    assertTensorClose(var_backup, bn.running_var.numpy(), max_err=0)
+    yv_expect = (xv - running_mean) / np.sqrt(running_var + bn.eps)
+    assertTensorClose(yv_expect, yv1.numpy(), max_err=5e-6)
+
+
 def test_batchnorm_no_stats():
     nr_chan = 8
     data_shape = (3, nr_chan, 4)
@@ -135,6 +301,31 @@ def test_batchnorm_no_stats():
     assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)


+def test_syncbn_no_stats():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 4)
+    bn = SyncBatchNorm(8, track_running_stats=False)
+    data = tensor()
+    for i in range(4):
+        if i == 2:
+            bn.training = False
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        mean = np.mean(np.mean(xv, axis=0, keepdims=True), axis=2, keepdims=True)
+        var = np.var(
+            np.transpose(xv, [0, 2, 1]).reshape(
+                (data_shape[0] * data_shape[2], nr_chan)
+            ),
+            axis=0,
+        ).reshape((1, nr_chan, 1))
+        sd = np.sqrt(var + bn.eps)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)
+
+
 def test_batchnorm2d_no_stats():
     nr_chan = 8
     data_shape = (3, nr_chan, 16, 16)
@@ -157,3 +348,27 @@ def test_batchnorm2d_no_stats():
         yv_expect = (xv - mean) / sd

         assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)


+def test_syncbn2d_no_stats():
+    nr_chan = 8
+    data_shape = (3, nr_chan, 16, 16)
+    bn = SyncBatchNorm(8, track_running_stats=False)
+    data = tensor()
+    for i in range(4):
+        if i == 2:
+            bn.training = False
+        xv = np.random.normal(loc=2.3, size=data_shape).astype(np.float32)
+        xv_transposed = np.transpose(xv, [0, 2, 3, 1]).reshape(
+            (data_shape[0] * data_shape[2] * data_shape[3], nr_chan)
+        )
+
+        mean = np.mean(xv_transposed, axis=0).reshape(1, nr_chan, 1, 1)
+        var = np.var(xv_transposed, axis=0).reshape((1, nr_chan, 1, 1))
+        sd = np.sqrt(var + bn.eps)
+
+        data.set_value(xv)
+        yv = bn(data)
+        yv_expect = (xv - mean) / sd
+
+        assertTensorClose(yv_expect, yv.numpy(), max_err=5e-6)