Unverified · Commit 2e8ad253, authored by Chang Xu, committed by GitHub

add_sync (#702)

* add_syncbn
Parent e7a02b5c
@@ -29,7 +29,7 @@ if pd_ver == 185:
     Layer = paddle.fluid.dygraph.Layer
 else:
     import paddle.nn as nn
-    from paddle.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding
+    from paddle.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding, SyncBatchNorm
     from paddle import ParamAttr
     from .layers import *
     from . import layers
@@ -287,6 +287,44 @@ class Convert:
                 ) if pd_ver == 185 else layers.SuperBatchNorm2D(**new_attr_dict)
                 model[idx] = layer
+            elif isinstance(layer, SyncBatchNorm) and (
+                    getattr(self.context, 'expand', None) != None or
+                    getattr(self.context, 'channel', None) != None):
+                # num_features in SyncBatchNorm does not change after the last weight operator
+                if idx > last_weight_layer_idx:
+                    continue
+
+                attr_dict = layer.__dict__
+                new_attr_name = ['momentum', 'epsilon', 'bias_attr']
+                new_attr_name += ['weight_attr', 'data_format', 'name']
+
+                self._change_name(layer, pd_ver)
+                new_attr_dict = dict.fromkeys(new_attr_name, None)
+                new_attr_dict['num_features'] = None
+                new_key = 'num_channels' if 'num_channels' in new_attr_dict.keys(
+                ) else 'num_features'
+
+                if self.context.expand:
+                    new_attr_dict[new_key] = int(
+                        self.context.expand *
+                        layer._parameters['weight'].shape[0])
+                elif self.context.channel:
+                    new_attr_dict[new_key] = max(cur_channel)
+                else:
+                    new_attr_dict[new_key] = attr_dict[
+                        '_num_channels'] if '_num_channels' in attr_dict.keys(
+                        ) else attr_dict['_num_features']
+
+                for attr in new_attr_name:
+                    new_attr_dict[attr] = attr_dict['_' + attr]
+
+                del layer, attr_dict
+
+                layer = layers.SuperSyncBatchNorm(**new_attr_dict)
+                model[idx] = layer
+
             ### assume output_size = None, filter_size != None
             ### NOTE: output_size != None may raise an error; handle it when it happens.
             elif isinstance(layer, Conv2DTranspose):
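The branch above mirrors the existing BatchNorm2D handling: when the supernet context carries an expand ratio or a channel list, the nn.SyncBatchNorm layer is rebuilt as layers.SuperSyncBatchNorm with num_features widened to the largest searchable width. Below is a minimal sketch of how this path would typically be exercised through the public OFA entry points (Convert and supernet); the toy model, layer sizes, and expand_ratio values are assumptions made for illustration, not part of this commit.

import paddle.nn as nn
from paddleslim.nas.ofa.convert_super import Convert, supernet

# Toy network containing a SyncBatchNorm layer (sizes are illustrative only).
model = nn.Sequential(
    nn.Conv2D(3, 8, 3, padding=1),
    nn.SyncBatchNorm(8),
    nn.ReLU())

# An 'expand' context makes Convert widen weighted layers and, with this patch,
# also swap nn.SyncBatchNorm for SuperSyncBatchNorm sized for the largest ratio.
sp_net_config = supernet(expand_ratio=[1, 2, 4])
super_model = Convert(sp_net_config).convert(model)
print(super_model)  # the SyncBatchNorm slot should now hold a SuperSyncBatchNorm sized for 8 * 4 channels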
@@ -651,13 +689,23 @@ class Convert:
                     raise NotImplementedError("name error")
             return net
+        def get_split_names(layer, name_list):
+            if name_list:
+                self.name_list.append(name_list)
+            for _, (name, sublayer) in enumerate(layer.named_children()):
+                if sublayer.named_children():
+                    get_split_names(sublayer, name_list + [name])
+
         if isinstance(network, Layer):
-            for idx, (name, sublayer) in enumerate(network.named_sublayers()):
-                if len(name.split('.')) > 1:
-                    net = split_prefix(network, name.split('.')[:-1])
+            curr_id = 0
+            self.name_list = []
+            get_split_names(network, [])
+            for idx, nl in enumerate(self.name_list):
+                if len(nl) > 1:
+                    net = split_prefix(network, nl[:-1])
                 else:
                     net = network
-                setattr(net, name.split('.')[-1], model[idx])
+                setattr(net, nl[-1], model[idx])
         return network
......
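The second hunk replaces the flat named_sublayers() walk with a recursive name-list traversal, so each converted layer is written back onto its immediate parent via an explicit attribute path. The snippet below is a framework-free sketch of that idea using a hypothetical Node class (not part of PaddleSlim); it only illustrates how collecting per-child name paths lets setattr reach the right parent, in the same spirit as get_split_names and split_prefix above.

class Node:
    def __init__(self, **children):
        self._children = children
        for k, v in children.items():
            setattr(self, k, v)

    def named_children(self):
        return list(self._children.items())


def collect_paths(node, path, out):
    # Record the attribute path of every descendant, then recurse.
    if path:
        out.append(path)
    for name, child in node.named_children():
        collect_paths(child, path + [name], out)


root = Node(backbone=Node(conv=Node(), bn=Node()), head=Node())
paths = []
collect_paths(root, [], paths)
print(paths)  # [['backbone'], ['backbone', 'conv'], ['backbone', 'bn'], ['head']]

# To replace the layer at path ['backbone', 'bn'], descend to its parent
# ('backbone') and setattr the last path component, mirroring split_prefix.
parent = root
for name in paths[2][:-1]:
    parent = getattr(parent, name)
setattr(parent, paths[2][-1], Node())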
@@ -29,7 +29,8 @@ __all__ = [
     'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D',
     'SuperBatchNorm2D', 'SuperLinear', 'SuperInstanceNorm2D',
     'SuperGroupConv2D', 'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose',
-    'SuperDepthwiseConv2DTranspose', 'SuperLayerNorm', 'SuperEmbedding'
+    'SuperDepthwiseConv2DTranspose', 'SuperLayerNorm', 'SuperEmbedding',
+    'SuperSyncBatchNorm'
 ]

 _logger = get_logger(__name__, level=logging.INFO)
@@ -956,6 +957,42 @@ class SuperBatchNorm2D(nn.BatchNorm2D):
             data_format=self._data_format)


+class SuperSyncBatchNorm(nn.SyncBatchNorm):
+    def __init__(self,
+                 num_features,
+                 momentum=0.9,
+                 epsilon=1e-05,
+                 weight_attr=None,
+                 bias_attr=None,
+                 data_format='NCHW',
+                 name=None):
+        super(SuperSyncBatchNorm,
+              self).__init__(num_features, momentum, epsilon, weight_attr,
+                             bias_attr, data_format, name)
+
+    def forward(self, input):
+        feature_dim = int(input.shape[1])
+
+        weight = self.weight[:feature_dim]
+        bias = self.bias[:feature_dim]
+        mean = self._mean[:feature_dim]
+        variance = self._variance[:feature_dim]
+
+        mean_out = mean
+        # variance and variance_out share the same memory
+        variance_out = variance
+
+        attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
+                 "is_test", not self.training, "data_layout", self._data_format,
+                 "use_mkldnn", False, "fuse_with_relu", False,
+                 "use_global_stats", False, 'trainable_statistics', False)
+
+        sync_batch_norm_out, _, _, _, _, _ = core.ops.sync_batch_norm(
+            input, weight, bias, mean, variance, mean_out, variance_out, *attrs)
+
+        return sync_batch_norm_out
+
+
 class SuperInstanceNorm2D(nn.InstanceNorm2D):
     """
     This interface is used to construct a callable object of the ``SuperBatchNorm2D`` class.
......
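SuperSyncBatchNorm allocates its parameters for the widest configuration and, in forward, slices weight, bias, and the running statistics down to the channel count of the incoming tensor before dispatching to core.ops.sync_batch_norm. The following is a simplified, framework-free sketch of that slicing idea in plain NumPy; it uses inference-style normalization with running statistics and deliberately omits the cross-device statistics synchronization that the real sync batch norm op performs, so it is an illustration of the mechanism rather than the actual layer.

import numpy as np

def elastic_batch_norm(x, gamma, beta, running_mean, running_var, eps=1e-5):
    # x: (N, C, H, W); parameters are sized for the maximum channel count,
    # but only the first C entries are used for the current input.
    feature_dim = x.shape[1]
    g = gamma[:feature_dim].reshape(1, -1, 1, 1)
    b = beta[:feature_dim].reshape(1, -1, 1, 1)
    mean = running_mean[:feature_dim].reshape(1, -1, 1, 1)
    var = running_var[:feature_dim].reshape(1, -1, 1, 1)
    return g * (x - mean) / np.sqrt(var + eps) + b

max_channels = 8
gamma = np.ones(max_channels, dtype=np.float32)
beta = np.zeros(max_channels, dtype=np.float32)
running_mean = np.zeros(max_channels, dtype=np.float32)
running_var = np.ones(max_channels, dtype=np.float32)

x = np.random.randn(2, 4, 6, 6).astype(np.float32)  # only 4 of 8 channels active
y = elastic_batch_norm(x, gamma, beta, running_mean, running_var)
print(y.shape)  # (2, 4, 6, 6)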
@@ -37,6 +37,7 @@ class ModelConv(nn.Layer):
             models = []
             models += [nn.Conv2D(3, 4, 3, padding=1)]
             models += [nn.InstanceNorm2D(4)]
+            models += [nn.SyncBatchNorm(4)]
             models += [ReLU()]
             models += [nn.Conv2D(4, 4, 3, groups=4)]
             models += [nn.InstanceNorm2D(4)]
......
@@ -30,6 +30,7 @@ class ModelCase1(nn.Layer):
         super(ModelCase1, self).__init__()
         models = [SuperConv2D(3, 4, 3, bias_attr=False)]
         models += [SuperConv2D(4, 4, 3, groups=4)]
+        models += [SuperSyncBatchNorm(4)]
         models += [SuperConv2D(4, 4, 3, groups=2)]
         models += [SuperConv2DTranspose(4, 4, 3, bias_attr=False)]
         models += [SuperConv2DTranspose(4, 4, 3, groups=4)]
......