未验证 提交 40e0684a 编写于 作者: C ceci3 提交者: GitHub

Refine ofa (#527)

* support 2.0
上级 f3b898c8
......@@ -14,4 +14,10 @@
from .ofa import OFA, RunConfig, DistillConfig
from .convert_super import supernet
from .layers import *
from .utils.utils import get_paddle_version
pd_ver = get_paddle_version()
if pd_ver == 185:
from .layers import *
else:
from .layers_new import *
......@@ -15,11 +15,23 @@
import inspect
import decorator
import logging
import paddle
import numbers
from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm, LayerNorm, Embedding
from .layers import *
import paddle
from ...common import get_logger
from .utils.utils import get_paddle_version
pd_ver = get_paddle_version()
# Bind `nn`, the plain layer classes, the Super* layers, the `layers` module
# alias and the `Layer` base class according to the detected Paddle version.
# NOTE(review): pd_ver == 185 presumably denotes Paddle 1.8.5 -- confirm
# against get_paddle_version().
if pd_ver == 185:
    import paddle.fluid.dygraph.nn as nn
    from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding
    from .layers import *
    from . import layers
    # `fluid` is never bound as a bare name in this module, so the base class
    # has to be reached through the already-imported `paddle` package.
    Layer = paddle.fluid.dygraph.Layer
else:
    import paddle.nn as nn
    from paddle.nn import Conv2D, Conv2DTranspose, Linear, LayerNorm, Embedding
    from .layers_new import *
    from . import layers_new as layers
    Layer = paddle.nn.Layer
_logger = get_logger(__name__, level=logging.INFO)
......@@ -28,19 +40,25 @@ __all__ = ['supernet']
WEIGHT_LAYER = ['conv', 'linear', 'embedding']
### TODO: add decorator
class Convert:
def __init__(self, context):
self.context = context
def convert(self, model):
def convert(self, network):
# search the first and last weight layer, don't change out channel of the last weight layer
# don't change in channel of the first weight layer
model = []
if isinstance(network, Layer):
for name, sublayer in network.named_sublayers():
model.append(sublayer)
else:
model = network
first_weight_layer_idx = -1
last_weight_layer_idx = -1
weight_layer_count = 0
# NOTE: pre_channel store for shortcut module
pre_channel = 0
pre_channel = None
cur_channel = None
for idx, layer in enumerate(model):
cls_name = layer.__class__.__name__.lower()
......@@ -61,50 +79,68 @@ class Convert:
key = attr_dict['_full_name']
new_attr_name = [
'_stride', '_dilation', '_groups', '_param_attr',
'_bias_attr', '_use_cudnn', '_act', '_dtype', '_padding'
'stride', 'padding', 'dilation', 'groups', 'bias_attr'
]
if pd_ver == 185:
new_attr_name += ['param_attr', 'use_cudnn', 'act', 'dtype']
else:
new_attr_name += [
'weight_attr', 'data_format', 'padding_mode'
]
new_attr_dict = dict()
new_attr_dict = dict.fromkeys(new_attr_name, None)
new_attr_dict['candidate_config'] = dict()
if pd_ver == 185:
new_attr_dict['num_channels'] = None
new_attr_dict['num_filters'] = None
new_attr_dict['filter_size'] = None
else:
new_attr_dict['in_channels'] = None
new_attr_dict['out_channels'] = None
new_attr_dict['kernel_size'] = None
self.kernel_size = getattr(self.context, 'kernel_size', None)
if self.kernel_size != None:
new_attr_dict['transform_kernel'] = True
# if the kernel_size of conv is 1, don't change it.
#if self.kernel_size and int(attr_dict['_filter_size'][0]) != 1:
if self.kernel_size and int(attr_dict['_filter_size']) != 1:
new_attr_dict['filter_size'] = max(self.kernel_size)
# The layer stores its kernel size under `_filter_size` or `_kernel_size`,
# depending on which attribute name is present on the layer.
fks = '_filter_size' if '_filter_size' in attr_dict else '_kernel_size'
# list(<int>) raises TypeError; wrap a scalar kernel size in a one-element
# list so that ks[0] works for both the scalar and the sequence form.
ks = [attr_dict[fks]] if isinstance(attr_dict[fks],
                                    numbers.Integral) else attr_dict[fks]
if self.kernel_size and int(ks[0]) != 1:
new_attr_dict['transform_kernel'] = True
new_attr_dict[fks[1:]] = max(self.kernel_size)
new_attr_dict['candidate_config'].update({
'kernel_size': self.kernel_size
})
else:
new_attr_dict['filter_size'] = attr_dict['_filter_size']
new_attr_dict[fks[1:]] = attr_dict[fks]
in_key = '_num_channels' if '_num_channels' in attr_dict.keys(
) else '_in_channels'
out_key = '_num_filters' if '_num_filters' in attr_dict.keys(
) else '_out_channels'
if self.context.expand:
### first super convolution
if idx == first_weight_layer_idx:
new_attr_dict['num_channels'] = attr_dict[
'_num_channels']
new_attr_dict[in_key[1:]] = attr_dict[in_key]
else:
new_attr_dict[
'num_channels'] = self.context.expand * attr_dict[
'_num_channels']
new_attr_dict[in_key[1:]] = int(self.context.expand *
attr_dict[in_key])
### last super convolution
if idx == last_weight_layer_idx:
new_attr_dict['num_filters'] = attr_dict['_num_filters']
new_attr_dict[out_key[1:]] = attr_dict[out_key]
else:
new_attr_dict[
'num_filters'] = self.context.expand * attr_dict[
'_num_filters']
new_attr_dict[out_key[1:]] = int(self.context.expand *
attr_dict[out_key])
new_attr_dict['candidate_config'].update({
'expand_ratio': self.context.expand_ratio
})
elif self.context.channel:
if attr_dict['_groups'] != None and (
int(attr_dict['_groups']) ==
int(attr_dict['_num_channels'])):
int(attr_dict['_groups']) == int(attr_dict[in_key])
):
### depthwise conv, if conv is depthwise, use pre channel as cur_channel
_logger.warn(
"If convolution is a depthwise conv, output channel change" \
......@@ -115,25 +151,27 @@ class Convert:
cur_channel = self.context.channel[0]
self.context.channel = self.context.channel[1:]
if idx == first_weight_layer_idx:
new_attr_dict['num_channels'] = attr_dict[
'_num_channels']
new_attr_dict[in_key[1:]] = attr_dict[in_key]
else:
new_attr_dict['num_channels'] = max(pre_channel)
new_attr_dict[in_key[1:]] = max(pre_channel)
if idx == last_weight_layer_idx:
new_attr_dict['num_filters'] = attr_dict['_num_filters']
new_attr_dict[out_key[1:]] = attr_dict[out_key]
else:
new_attr_dict['num_filters'] = max(cur_channel)
new_attr_dict[out_key[1:]] = max(cur_channel)
new_attr_dict['candidate_config'].update({
'channel': cur_channel
})
pre_channel = cur_channel
else:
new_attr_dict['num_filters'] = attr_dict['_num_filters']
new_attr_dict['num_channels'] = attr_dict['_num_channels']
new_attr_dict[in_key[1:]] = attr_dict[in_key]
new_attr_dict[out_key[1:]] = attr_dict[out_key]
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
if attr == 'weight_attr':
new_attr_dict[attr] = attr_dict['_param_attr']
else:
new_attr_dict[attr] = attr_dict['_' + attr]
del layer
......@@ -141,17 +179,15 @@ class Convert:
'_groups']) == 1:
### standard conv
layer = Block(SuperConv2D(**new_attr_dict), key=key)
elif int(attr_dict['_groups']) == int(attr_dict[
'_num_channels']):
elif int(attr_dict['_groups']) == int(attr_dict[in_key]):
# if conv is depthwise conv, groups = in_channel, out_channel = in_channel,
# channel in candidate_config = in_channel_list
if 'channel' in new_attr_dict['candidate_config']:
new_attr_dict['num_channels'] = max(cur_channel)
new_attr_dict['num_filters'] = new_attr_dict[
'num_channels']
new_attr_dict[in_key[1:]] = max(cur_channel)
new_attr_dict[out_key[1:]] = new_attr_dict[in_key[1:]]
new_attr_dict['candidate_config'][
'channel'] = cur_channel
new_attr_dict['groups'] = new_attr_dict['num_channels']
new_attr_dict['groups'] = new_attr_dict[in_key[1:]]
layer = Block(
SuperDepthwiseConv2D(**new_attr_dict), key=key)
else:
......@@ -159,7 +195,8 @@ class Convert:
layer = Block(SuperGroupConv2D(**new_attr_dict), key=key)
model[idx] = layer
elif isinstance(layer, BatchNorm) and (
elif isinstance(layer,
getattr(nn, 'BatchNorm2D', nn.BatchNorm)) and (
getattr(self.context, 'expand', None) != None or
getattr(self.context, 'channel', None) != None):
# num_features in BatchNorm don't change after last weight operators
......@@ -167,26 +204,41 @@ class Convert:
continue
attr_dict = layer.__dict__
new_attr_name = [
'_param_attr', '_bias_attr', '_act', '_dtype', '_in_place',
'_data_layout', '_momentum', '_epsilon', '_is_test',
'_use_global_stats', '_trainable_statistics'
new_attr_name = ['momentum', 'epsilon', 'bias_attr']
if pd_ver == 185:
new_attr_name += [
'param_attr', 'act', 'dtype', 'in_place', 'data_layout',
'is_test', 'use_global_stats', 'trainable_statistics'
]
new_attr_dict = dict()
else:
new_attr_name += ['weight_attr', 'data_format', 'name']
new_attr_dict = dict.fromkeys(new_attr_name, None)
if pd_ver == 185:
new_attr_dict['num_channels'] = None
else:
new_attr_dict['num_features'] = None
new_key = 'num_channels' if 'num_channels' in new_attr_dict.keys(
) else 'num_features'
if self.context.expand:
new_attr_dict['num_channels'] = self.context.expand * int(
new_attr_dict[new_key] = int(
self.context.expand *
layer._parameters['weight'].shape[0])
elif self.context.channel:
new_attr_dict['num_channels'] = max(cur_channel)
new_attr_dict[new_key] = max(cur_channel)
else:
new_attr_dict['num_channels'] = attr_dict['_num_channels']
new_attr_dict[new_key] = attr_dict[
'_num_channels'] if '_num_channels' in attr_dict.keys(
) else attr_dict['_num_features']
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
new_attr_dict[attr] = attr_dict['_' + attr]
del layer, attr_dict
layer = SuperBatchNorm(**new_attr_dict)
# Resolve the batch-norm super layer lazily: a getattr() default is evaluated
# eagerly, so naming SuperBatchNorm2D directly raises NameError whenever the
# star-import did not provide it.
super_bn_cls = getattr(layers, 'SuperBatchNorm', None) or getattr(
    layers, 'SuperBatchNorm2D')
layer = super_bn_cls(**new_attr_dict)
model[idx] = layer
### assume output_size = None, filter_size != None
......@@ -196,52 +248,72 @@ class Convert:
key = attr_dict['_full_name']
new_attr_name = [
'_stride', '_dilation', '_groups', '_param_attr',
'_padding', '_bias_attr', '_use_cudnn', '_act', '_dtype',
'_output_size'
'stride', 'padding', 'dilation', 'groups', 'bias_attr'
]
assert attr_dict[
'_filter_size'] != None, "Conv2DTranspose only support filter size != None now"
# attr_dict is a plain dict (layer.__dict__), so it must be queried with
# .get(); getattr() on a dict never looks at its keys and the previous check
# could not fail.
assert attr_dict.get('_filter_size', attr_dict.get(
    '_kernel_size')) is not None, "Conv2DTranspose only support kernel size != None now"
new_attr_dict = dict()
if pd_ver == 185:
new_attr_name += [
'output_size', 'param_attr', 'use_cudnn', 'act', 'dtype'
]
else:
new_attr_name += [
'output_padding', 'weight_attr', 'data_format'
]
new_attr_dict = dict.fromkeys(new_attr_name, None)
new_attr_dict['candidate_config'] = dict()
if pd_ver == 185:
new_attr_dict['num_channels'] = None
new_attr_dict['num_filters'] = None
new_attr_dict['filter_size'] = None
else:
new_attr_dict['in_channels'] = None
new_attr_dict['out_channels'] = None
new_attr_dict['kernel_size'] = None
self.kernel_size = getattr(self.context, 'kernel_size', None)
if self.kernel_size != None:
new_attr_dict['transform_kernel'] = True
# if the kernel_size of conv transpose is 1, don't change it.
if self.kernel_size and int(attr_dict['_filter_size'][0]) != 1:
new_attr_dict['filter_size'] = max(self.kernel_size)
# The layer stores its kernel size under `_filter_size` or `_kernel_size`,
# depending on which attribute name is present on the layer.
fks = '_filter_size' if '_filter_size' in attr_dict else '_kernel_size'
# list(<int>) raises TypeError; wrap a scalar kernel size in a one-element
# list so that ks[0] works for both the scalar and the sequence form.
ks = [attr_dict[fks]] if isinstance(attr_dict[fks],
                                    numbers.Integral) else attr_dict[fks]
if self.kernel_size and int(ks[0]) != 1:
new_attr_dict['transform_kernel'] = True
new_attr_dict[fks[1:]] = max(self.kernel_size)
new_attr_dict['candidate_config'].update({
'kernel_size': self.kernel_size
})
else:
new_attr_dict['filter_size'] = attr_dict['_filter_size']
new_attr_dict[fks[1:]] = attr_dict[fks]
in_key = '_num_channels' if '_num_channels' in attr_dict.keys(
) else '_in_channels'
out_key = '_num_filters' if '_num_filters' in attr_dict.keys(
) else '_out_channels'
if self.context.expand:
### first super convolution transpose
if idx == first_weight_layer_idx:
new_attr_dict['num_channels'] = attr_dict[
'_num_channels']
new_attr_dict[in_key[1:]] = attr_dict[in_key]
else:
new_attr_dict[
'num_channels'] = self.context.expand * attr_dict[
'_num_channels']
new_attr_dict[in_key[1:]] = int(self.context.expand *
attr_dict[in_key])
### last super convolution transpose
if idx == last_weight_layer_idx:
new_attr_dict['num_filters'] = attr_dict['_num_filters']
new_attr_dict[out_key[1:]] = attr_dict[out_key]
else:
new_attr_dict[
'num_filters'] = self.context.expand * attr_dict[
'_num_filters']
new_attr_dict[out_key[1:]] = int(self.context.expand *
attr_dict[out_key])
new_attr_dict['candidate_config'].update({
'expand_ratio': self.context.expand_ratio
})
elif self.context.channel:
if attr_dict['_groups'] != None and (
int(attr_dict['_groups']) ==
int(attr_dict['_num_channels'])):
int(attr_dict['_groups']) == int(attr_dict[in_key])
):
### depthwise conv_transpose
_logger.warn(
"If convolution is a depthwise conv_transpose, output channel " \
......@@ -252,29 +324,33 @@ class Convert:
cur_channel = self.context.channel[0]
self.context.channel = self.context.channel[1:]
if idx == first_weight_layer_idx:
new_attr_dict['num_channels'] = attr_dict[
'_num_channels']
new_attr_dict[in_key[1:]] = attr_dict[in_key]
else:
new_attr_dict['num_channels'] = max(pre_channel)
new_attr_dict[in_key[1:]] = max(pre_channel)
if idx == last_weight_layer_idx:
new_attr_dict['num_filters'] = attr_dict['_num_filters']
new_attr_dict[out_key[1:]] = attr_dict[out_key]
else:
new_attr_dict['num_filters'] = max(cur_channel)
new_attr_dict[out_key[1:]] = max(cur_channel)
new_attr_dict['candidate_config'].update({
'channel': cur_channel
})
pre_channel = cur_channel
else:
new_attr_dict['num_filters'] = attr_dict['_num_filters']
new_attr_dict['num_channels'] = attr_dict['_num_channels']
new_attr_dict[in_key[1:]] = attr_dict[in_key]
new_attr_dict[out_key[1:]] = attr_dict[out_key]
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
if attr == 'weight_attr':
new_attr_dict[attr] = attr_dict['_param_attr']
elif attr == 'output_padding':
new_attr_dict[attr] = attr_dict[attr]
else:
new_attr_dict[attr] = attr_dict['_' + attr]
del layer
if new_attr_dict['output_size'] == []:
# new_attr_dict is a dict, so use .get(): an empty output_size list is
# normalised to None, and a missing key is simply left alone.
if new_attr_dict.get('output_size') == []:
    new_attr_dict['output_size'] = None
if attr_dict['_groups'] == None or int(attr_dict[
......@@ -282,17 +358,15 @@ class Convert:
### standard conv_transpose
layer = Block(
SuperConv2DTranspose(**new_attr_dict), key=key)
elif int(attr_dict['_groups']) == int(attr_dict[
'_num_channels']):
elif int(attr_dict['_groups']) == int(attr_dict[in_key]):
# if conv is depthwise conv, groups = in_channel, out_channel = in_channel,
# channel in candidate_config = in_channel_list
if 'channel' in new_attr_dict['candidate_config']:
new_attr_dict['num_channels'] = max(cur_channel)
new_attr_dict['num_filters'] = new_attr_dict[
'num_channels']
new_attr_dict[in_key[1:]] = max(cur_channel)
new_attr_dict[out_key[1:]] = new_attr_dict[in_key[1:]]
new_attr_dict['candidate_config'][
'channel'] = cur_channel
new_attr_dict['groups'] = new_attr_dict['num_channels']
new_attr_dict['groups'] = new_attr_dict[in_key[1:]]
layer = Block(
SuperDepthwiseConv2DTranspose(**new_attr_dict), key=key)
else:
......@@ -306,25 +380,39 @@ class Convert:
getattr(self.context, 'channel', None) != None):
attr_dict = layer.__dict__
key = attr_dict['_full_name']
### TODO(paddle): add _param_attr and _bias_attr as private variable of Linear
#new_attr_name = ['_act', '_dtype', '_param_attr', '_bias_attr']
new_attr_name = ['_act', '_dtype']
if pd_ver == 185:
new_attr_name = ['param_attr', 'bias_attr', 'act', 'dtype']
else:
new_attr_name = ['weight_attr', 'bias_attr']
in_nc, out_nc = layer._parameters['weight'].shape
new_attr_dict = dict()
new_attr_dict = dict.fromkeys(new_attr_name, None)
new_attr_dict['candidate_config'] = dict()
if pd_ver == 185:
new_attr_dict['input_dim'] = None
new_attr_dict['output_dim'] = None
else:
new_attr_dict['in_features'] = None
new_attr_dict['out_features'] = None
in_key = '_input_dim' if '_input_dim' in attr_dict.keys(
) else '_in_features'
out_key = '_output_dim' if '_output_dim' in attr_dict.keys(
) else '_out_features'
attr_dict[in_key] = in_nc
attr_dict[out_key] = out_nc
if self.context.expand:
if idx == first_weight_layer_idx:
new_attr_dict['input_dim'] = int(in_nc)
new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
else:
new_attr_dict['input_dim'] = self.context.expand * int(
in_nc)
new_attr_dict[in_key[1:]] = int(self.context.expand *
attr_dict[in_key])
if idx == last_weight_layer_idx:
new_attr_dict['output_dim'] = int(out_nc)
new_attr_dict[out_key[1:]] = int(attr_dict[out_key])
else:
new_attr_dict['output_dim'] = self.context.expand * int(
out_nc)
new_attr_dict[out_key[1:]] = int(self.context.expand *
attr_dict[out_key])
new_attr_dict['candidate_config'].update({
'expand_ratio': self.context.expand_ratio
})
......@@ -332,31 +420,34 @@ class Convert:
cur_channel = self.context.channel[0]
self.context.channel = self.context.channel[1:]
if idx == first_weight_layer_idx:
new_attr_dict['input_dim'] = int(in_nc)
new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
else:
new_attr_dict['input_dim'] = max(pre_channel)
new_attr_dict[in_key[1:]] = max(pre_channel)
if idx == last_weight_layer_idx:
new_attr_dict['output_dim'] = int(out_nc)
new_attr_dict[out_key[1:]] = int(attr_dict[out_key])
else:
new_attr_dict['output_dim'] = max(cur_channel)
new_attr_dict[out_key[1:]] = max(cur_channel)
new_attr_dict['candidate_config'].update({
'channel': cur_channel
})
pre_channel = cur_channel
else:
new_attr_dict['input_dim'] = int(in_nc)
new_attr_dict['output_dim'] = int(out_nc)
new_attr_dict[in_key[1:]] = int(attr_dict[in_key])
new_attr_dict[out_key[1:]] = int(attr_dict[out_key])
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
new_attr_dict[attr] = attr_dict['_' + attr]
del layer, attr_dict
layer = Block(SuperLinear(**new_attr_dict), key=key)
model[idx] = layer
elif isinstance(layer, InstanceNorm) and (
elif isinstance(
layer,
getattr(nn, 'InstanceNorm2D',
paddle.fluid.dygraph.nn.InstanceNorm)) and (
getattr(self.context, 'expand', None) != None or
getattr(self.context, 'channel', None) != None):
# num_features in InstanceNorm don't change after last weight operators
......@@ -364,24 +455,38 @@ class Convert:
continue
attr_dict = layer.__dict__
if pd_ver == 185:
new_attr_name = [
'_param_attr', '_bias_attr', '_dtype', '_epsilon'
'bias_attr', 'epsilon', 'param_attr', 'dtype'
]
new_attr_dict = dict()
else:
new_attr_name = ['bias_attr', 'epsilon', 'weight_attr']
new_attr_dict = dict.fromkeys(new_attr_name, None)
if pd_ver == 185:
new_attr_dict['num_channels'] = None
else:
new_attr_dict['num_features'] = None
# new_attr_dict holds constructor keyword names (no leading underscore), so
# test for 'num_channels' there; new_key itself keeps the underscore because
# it is used as a key into the layer's attribute dict and stripped later.
new_key = '_num_channels' if 'num_channels' in new_attr_dict else '_num_features'
### 10 is a default channel in the case of weight_attr=False, in this condition, num of channels if useless, so give it arbitrarily.
attr_dict[new_key] = layer._parameters['scale'].shape[0] if len(
layer._parameters) != 0 else 10
if self.context.expand:
new_attr_dict['num_channels'] = self.context.expand * int(
layer._parameters['scale'].shape[0])
new_attr_dict[new_key[1:]] = int(self.context.expand *
attr_dict[new_key])
elif self.context.channel:
new_attr_dict['num_channels'] = max(cur_channel)
new_attr_dict[new_key[1:]] = max(cur_channel)
else:
new_attr_dict['num_channels'] = attr_dict['_num_channels']
new_attr_dict[new_key[1:]] = attr_dict[new_key]
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
new_attr_dict[attr] = attr_dict['_' + attr]
del layer, attr_dict
layer = SuperInstanceNorm(**new_attr_dict)
# Fall back to the class object, not to its name: the previous default was the
# string 'SuperInstanceNorm', which is not callable.
super_in_cls = getattr(layers, 'SuperInstanceNorm2D', None) or getattr(
    layers, 'SuperInstanceNorm')
layer = super_in_cls(**new_attr_dict)
model[idx] = layer
elif isinstance(layer, LayerNorm) and (
......@@ -392,15 +497,19 @@ class Convert:
continue
attr_dict = layer.__dict__
new_attr_name = [
'_scale', '_shift', '_param_attr', '_bias_attr', '_act',
'_dtype', '_epsilon'
new_attr_name = ['epsilon', 'bias_attr']
if pd_ver == 185:
new_attr_name += [
'scale', 'shift', 'param_attr', 'act', 'dtype'
]
new_attr_dict = dict()
else:
new_attr_name += ['weight_attr']
new_attr_dict = dict.fromkeys(new_attr_name, None)
new_attr_dict['normalized_shape'] = None
if self.context.expand:
new_attr_dict[
'normalized_shape'] = self.context.expand * int(
attr_dict['_normalized_shape'][0])
new_attr_dict['normalized_shape'] = int(
self.context.expand * attr_dict['_normalized_shape'][0])
elif self.context.channel:
new_attr_dict['normalized_shape'] = max(cur_channel)
else:
......@@ -408,7 +517,7 @@ class Convert:
'_normalized_shape']
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
new_attr_dict[attr] = attr_dict['_' + attr]
del layer, attr_dict
layer = SuperLayerNorm(**new_attr_dict)
......@@ -419,18 +528,32 @@ class Convert:
getattr(self.context, 'channel', None) != None):
attr_dict = layer.__dict__
key = attr_dict['_full_name']
new_attr_name = [
'_is_sparse', '_is_distributed', '_padding_idx',
'_param_attr', '_dtype'
new_attr_name = ['padding_idx', ]
if pd_ver == 185:
new_attr_name += [
'size', 'is_sparse', 'is_distributed', 'param_attr',
'dtype'
]
else:
new_attr_name += [
'num_embeddings', 'embedding_dim', 'sparse',
'weight_attr', 'name'
]
new_attr_dict = dict()
new_attr_dict = dict.fromkeys(new_attr_name, None)
new_attr_dict['candidate_config'] = dict()
bef_size = attr_dict['_size']
if self.context.expand:
if pd_ver == 185:
new_attr_dict['size'] = [
bef_size[0], self.context.expand * bef_size[1]
bef_size[0], int(self.context.expand * bef_size[1])
]
else:
new_attr_dict['num_embeddings'] = attr_dict[
'_num_embeddings']
new_attr_dict['embedding_dim'] = int(
self.context.expand * attr_dict['_embedding_dim'])
new_attr_dict['candidate_config'].update({
'expand_ratio': self.context.expand_ratio
})
......@@ -438,23 +561,52 @@ class Convert:
elif self.context.channel:
cur_channel = self.context.channel[0]
self.context.channel = self.context.channel[1:]
if pd_ver == 185:
new_attr_dict['size'] = [bef_size[0], max(cur_channel)]
else:
new_attr_dict['num_embeddings'] = attr_dict[
'_num_embeddings']
new_attr_dict['embedding_dim'] = max(cur_channel)
new_attr_dict['candidate_config'].update({
'channel': cur_channel
})
pre_channel = cur_channel
else:
# Neither expand nor channel is configured: keep the embedding's original
# dimensions. (`pf_ver` was a typo for the module-level `pd_ver`.)
if pd_ver == 185:
    new_attr_dict['size'] = bef_size
else:
    new_attr_dict['num_embeddings'] = attr_dict['_num_embeddings']
    new_attr_dict['embedding_dim'] = attr_dict['_embedding_dim']
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
new_attr_dict[attr] = attr_dict['_' + attr]
del layer, attr_dict
layer = Block(SuperEmbedding(**new_attr_dict), key=key)
model[idx] = layer
return model
def split_prefix(net, name_list):
    """Follow a chain of attribute names starting from ``net``.

    Parameters:
        net: object to start the lookup from.
        name_list: sequence of attribute names, resolved left to right.

    Returns:
        The object reached after resolving every name in ``name_list``.

    Raises:
        NotImplementedError: if ``name_list`` is empty.
    """
    if len(name_list) == 0:
        raise NotImplementedError("name error")
    current = net
    for attr_name in name_list:
        current = getattr(current, attr_name)
    return current
if isinstance(network, Layer):
for idx, (name, sublayer) in enumerate(network.named_sublayers()):
if len(name.split('.')) > 1:
net = split_prefix(network, name.split('.')[:-1])
else:
net = network
setattr(net, name.split('.')[-1], model[idx])
return network
class supernet:
......@@ -474,12 +626,16 @@ class supernet:
self.expand = max(self.expand_ratio)
elif isinstance(self.expand_ratio, int):
self.expand = self.expand_ratio
if 'channel' not in kwargs.keys():
self.channel = None
def __enter__(self):
return Convert(self)
def __exit__(self, exc_type, exc_val, exc_tb):
pass
self.expand = None
self.channel = None
self.kernel_size = None
#def ofa_supernet(kernel_size, expand_ratio):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import logging
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.fluid.core as core
from ...common import get_logger
from .utils.utils import compute_start_end, get_same_padding, convert_to_list
__all__ = [
'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D',
'SuperBatchNorm2D', 'SuperLinear', 'SuperInstanceNorm2D', 'Block',
'SuperGroupConv2D', 'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose',
'SuperDepthwiseConv2DTranspose', 'SuperLayerNorm', 'SuperEmbedding'
]
_logger = get_logger(__name__, level=logging.INFO)
### TODO: if task is elastic width, need to add re_organize_middle_weight in 1x1 conv in MBBlock
# Module-wide running count, advanced by counter().
_cnt = 0


def counter():
    """Increment the module-level count by one and return the new value."""
    global _cnt
    _cnt = _cnt + 1
    return _cnt
class BaseBlock(paddle.nn.Layer):
    """Layer that carries a string key and an optional supernet reference.

    Parameters:
        key(optional): identifier for this block; it is converted with
            ``str()``. When omitted, the key is the class name followed by
            the next value of the module-level counter. Default: None.
    """

    def __init__(self, key=None):
        super(BaseBlock, self).__init__()
        if key is None:
            # No explicit key: derive one from the class name and counter.
            self._key = self.__class__.__name__ + str(counter())
        else:
            self._key = str(key)

    def set_supernet(self, supernet):
        """Attach the SuperNet object to this block."""
        # Stored straight into __dict__ rather than through attribute
        # assignment. NOTE(review): presumably so the supernet is not
        # registered as a sublayer -- confirm.
        self.__dict__['supernet'] = supernet

    @property
    def key(self):
        """Key assigned to this block at construction time."""
        return self._key
class Block(BaseBlock):
    """
    Wrapper around a single super layer; a model is composed of nested blocks.

    Parameters:
        fn(Layer): instance of a super layer, such as: SuperConv2D(3, 5, 3).
        fixed(bool, optional): stored on the block as ``self.fixed``.
            Default: False.
        key(str, optional): key of this layer, one-to-one correspondence
            between key and candidate config. Default: None.
    """

    def __init__(self, fn, fixed=False, key=None):
        super(Block, self).__init__(key)
        self.fn = fn
        self.fixed = fixed
        # Expose the wrapped layer's candidate config on the block itself.
        self.candidate_config = self.fn.candidate_config

    def forward(self, *inputs, **kwargs):
        """Delegate the call to the attached supernet's ``layers_forward``."""
        return self.supernet.layers_forward(self, *inputs, **kwargs)
class SuperConv2D(nn.Conv2D):
"""
This interface is used to construct a callable object of the ``SuperConv2D`` class.
The difference between ```SuperConv2D``` and ```Conv2D``` is: ```SuperConv2D``` need
to feed a config dictionary with the format of {'channel', num_of_channel} represents
the channels of the outputs, used to change the first dimension of weight and bias,
only train the first channels of the weight and bias.
Note: the channel in config need to less than first defined.
The super convolution2D layer calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input and
Output are in NCHW format, where N is batch size, C is the number of
the feature map, H is the height of the feature map, and W is the width of the feature map.
Filter's shape is [MCHW] , where M is the number of output feature map,
C is the number of input feature map, H is the height of the filter,
and W is the width of the filter. If the groups is greater than 1,
C will equal the number of input feature map divided by the groups.
Please refer to UFLDL's `convolution
<http://ufldl.stanford.edu/tutorial/supervised/FeatureExtractionUsingConvolution/>`_
for more details.
If bias attribution and activation type are provided, bias is added to the
output of the convolution, and the corresponding activation function is
applied to the final result.
For each input :math:`X`, the equation is:
.. math::
Out = \\sigma (W \\ast X + b)
Where:
* :math:`X`: Input value, a ``Tensor`` with NCHW format.
* :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] .
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D ``Tensor`` with shape [M, 1].
* :math:`\\sigma`: Activation function.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Example:
- Input:
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Parameters:
num_channels(int): The number of channels in the input image.
num_filters(int): The number of filter. It is as same as the output
feature map.
filter_size (int or tuple): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
candidate_config(dict, optional): Dictionary descripts candidate config of this layer,
such as {'kernel_size': (3, 5, 7), 'channel': (4, 6, 8)}, means the kernel size of
this layer can be choose from (3, 5, 7), the key of candidate_config
only can be 'kernel_size', 'channel' and 'expand_ratio', 'channel' and 'expand_ratio'
CANNOT be set at the same time. Default: None.
transform_kernel(bool, optional): Whether to use transform matrix to transform a large filter
to a small filter. Default: False.
stride (int or tuple, optional): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: 1.
padding (int or tuple, optional): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: 0.
dilation (int or tuple, optional): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: 1.
groups (int, optional): The groups number of the Conv2d Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: 1.
param_attr (ParamAttr, optional): The parameter attribute for learnable weights(Parameter)
of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
bias_attr (ParamAttr or bool, optional): The attribute for the bias of conv2d.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn (bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True.
act (str, optional): Activation type, if it is set to None, activation is not appended.
Default: None.
dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
Attribute:
**weight** (Parameter): the learnable weights of filter of this layer.
**bias** (Parameter or None): the learnable bias of this layer.
Returns:
None
Raises:
ValueError: if ``use_cudnn`` is not a bool value.
Examples:
.. code-block:: python
import paddle
from paddleslim.nas.ofa.layers import SuperConv2D
import numpy as np
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
super_conv2d = SuperConv2D(3, 10, 3)
config = {'channel': 5}
data = paddle.to_variable(data)
conv = super_conv2d(data, config)
"""
### NOTE: filter_size, num_channels and num_filters must be the max of candidate to define a largest network.
def __init__(self,
in_channels,
out_channels,
kernel_size,
candidate_config={},
transform_kernel=False,
stride=1,
padding=0,
dilation=1,
groups=1,
padding_mode='zeros',
weight_attr=None,
bias_attr=None,
data_format='NCHW'):
super(SuperConv2D, self).__init__(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=padding,
padding_mode=padding_mode,
dilation=dilation,
groups=groups,
weight_attr=weight_attr,
bias_attr=bias_attr,
data_format=data_format)
self.candidate_config = candidate_config
if len(candidate_config.items()) != 0:
for k, v in candidate_config.items():
candidate_config[k] = list(set(v))
self.ks_set = candidate_config[
'kernel_size'] if 'kernel_size' in candidate_config else None
self.expand_ratio = candidate_config[
'expand_ratio'] if 'expand_ratio' in candidate_config else None
self.channel = candidate_config[
'channel'] if 'channel' in candidate_config else None
self.base_channel = self._out_channels
if self.expand_ratio != None:
self.base_channel = int(self._out_channels / max(self.expand_ratio))
self.transform_kernel = transform_kernel
if self.ks_set != None:
self.ks_set.sort()
if self.transform_kernel != False:
scale_param = dict()
### create parameter to transform kernel
for i in range(len(self.ks_set) - 1):
ks_small = self.ks_set[i]
ks_large = self.ks_set[i + 1]
param_name = '%dto%d_matrix' % (ks_large, ks_small)
ks_t = ks_small**2
scale_param[param_name] = self.create_parameter(
attr=paddle.ParamAttr(
name=self._full_name + param_name,
initializer=nn.initializer.Assign(np.eye(ks_t))),
shape=(ks_t, ks_t),
dtype=self._dtype)
for name, param in scale_param.items():
setattr(self, name, param)
def get_active_filter(self, in_nc, out_nc, kernel_size):
    """Return the slice of ``self.weight`` used for the requested sub-network.

    Args:
        in_nc(int): Number of leading input channels to keep (weight dim 1).
        out_nc(int): Number of leading output channels to keep (weight dim 0).
        kernel_size(int): Active kernel size.

    Returns:
        The centre crop of the largest filter, or, when kernel transformation
        is enabled and a smaller kernel is requested, the filter obtained by
        repeatedly cropping and multiplying with the learned transform
        matrices.
    """
    largest_ks = self._kernel_size[0]
    lo, hi = compute_start_end(largest_ks, kernel_size)
    ### without kernel transformation: take the centre window of the largest filter
    filters = self.weight[:out_nc, :in_nc, lo:hi, lo:hi]
    if not (self.transform_kernel != False and kernel_size < largest_ks):
        return filters

    ### with kernel transformation: shrink step by step, largest size first
    current = self.weight[:out_nc, :in_nc, :, :]
    for idx in reversed(range(1, len(self.ks_set))):
        src_ks = self.ks_set[idx]
        if src_ks <= kernel_size:
            # already at (or below) the requested size
            break
        target_ks = self.ks_set[idx - 1]
        lo, hi = compute_start_end(src_ks, target_ks)
        cropped = current[:, :, lo:hi, lo:hi]
        # flatten every (out, in) kernel into one row so that a single
        # matmul applies the transform to all of them
        flat = paddle.reshape(
            cropped, shape=[(cropped.shape[0] * cropped.shape[1]), -1])
        transform = getattr(self, '%dto%d_matrix' % (src_ks, target_ks))
        flat = paddle.matmul(flat, transform, False, False)
        current = paddle.reshape(
            flat,
            shape=[filters.shape[0], filters.shape[1], target_ks, target_ks])
    return current
def get_groups_in_out_nc(self, in_nc, out_nc):
    """Return ``(groups, weight_in_nc, weight_out_nc)`` for a standard conv.

    The channel counts are passed through untouched and the configured
    group count of the layer is reported.
    """
    ### standard conv
    group_count = self._groups
    return group_count, in_nc, out_nc
def forward(self, input, kernel_size=None, expand_ratio=None, channel=None):
    """Run the convolution with the requested sub-network configuration.

    Args:
        input: Input tensor; its dimension 1 is read as the number of
            active input channels.
        kernel_size(int, optional): Active kernel size. Defaults to the
            largest kernel size.
        expand_ratio(float, optional): Output channels are
            ``expand_ratio * self.base_channel``.
        channel(int, optional): Explicit number of output channels.
            ``expand_ratio`` and ``channel`` are mutually exclusive.

    Returns:
        The output of ``F.conv2d`` computed with the sliced weight and bias.
    """
    self.cur_config = {
        'kernel_size': kernel_size,
        'expand_ratio': expand_ratio,
        'channel': channel
    }
    in_nc = int(input.shape[1])
    assert (
        expand_ratio is None or channel is None
    ), "expand_ratio and channel CANNOT be NOT None at the same time."
    if expand_ratio is not None:
        out_nc = int(expand_ratio * self.base_channel)
    elif channel is not None:
        out_nc = int(channel)
    else:
        out_nc = self._out_channels
    ks = int(self._kernel_size[0]) if kernel_size is None else int(
        kernel_size)

    # Subclasses (group / depthwise) adjust the group count and the weight
    # slice sizes for the active channels.
    groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
                                                                    out_nc)
    weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)

    if kernel_size is not None or 'kernel_size' in self.candidate_config:
        # the kernel size is searchable: always use 'same' padding
        padding = convert_to_list(get_same_padding(ks), 2)
    else:
        padding = self._padding

    # The bias must match the output channels the sliced weight produces,
    # which for a depthwise conv is not the requested out_nc.
    bias = self.bias[:weight_out_nc] if self.bias is not None else None

    # Use the group count computed above; self._groups is only valid for the
    # largest network and breaks depthwise convs on a pruned input.
    return F.conv2d(
        input,
        weight,
        bias=bias,
        stride=self._stride,
        padding=padding,
        dilation=self._dilation,
        groups=groups,
        data_format=self._data_format)
class SuperGroupConv2D(SuperConv2D):
    """``SuperConv2D`` variant for grouped convolution."""

    def get_groups_in_out_nc(self, in_nc, out_nc):
        """Return ``(groups, in_nc // groups, out_nc)``.

        The weight of a grouped conv is laid out as (Cout, Cin/G, Kh, Kw), so
        the input-channel slice is the per-group channel count.
        """
        ### groups convolution
        group_count = self._groups
        per_group_in_nc = int(in_nc // group_count)
        return group_count, per_group_in_nc, out_nc
class SuperDepthwiseConv2D(SuperConv2D):
    """``SuperConv2D`` variant for depthwise convolution."""

    def get_groups_in_out_nc(self, in_nc, out_nc):
        """Return ``(in_nc, in_nc, in_nc)``.

        A depthwise conv uses one group per input channel, so the requested
        ``out_nc`` is overridden by ``in_nc``; a debug message is logged when
        the two differ.
        """
        if in_nc != out_nc:
            _logger.debug(
                ("input channel and output channel in depthwise conv is different, "
                 "change output channel to input channel! "
                 "origin channel:(in_nc {}, out_nc {}): ").format(in_nc,
                                                                  out_nc))
        # groups == input channels == output channels
        return in_nc, in_nc, in_nc
class SuperConv2DTranspose(nn.Conv2DTranspose):
    r"""
    This interface is used to construct a callable object of the ``SuperConv2DTranspose``
    class.

    The difference between ``SuperConv2DTranspose`` and ``Conv2DTranspose`` is:
    ``SuperConv2DTranspose`` accepts ``kernel_size``, ``expand_ratio`` or ``channel``
    when it is called. They select the active kernel size and the number of output
    channels, so only the leading slice of the weight and bias is used (and trained).

    Note: the requested channel must not exceed the value the layer was created with.

    The super convolution2D transpose layer calculates the output based on the input,
    filter, and dilations, strides, paddings. Input and output
    are in NCHW format. Where N is batch size, C is the number of feature map,
    H is the height of the feature map, and W is the width of the feature map.
    Filter's shape is [MCHW] , where M is the number of input feature map,
    C is the number of output feature map, H is the height of the filter,
    and W is the width of the filter. If the groups is greater than 1,
    C will equal the number of output feature map divided by the groups.
    If a bias is present it is added to the output of the convolution.

    The details of convolution transpose layer, please refer to the following explanation and references
    `conv2dtranspose <http://www.matthewzeiler.com/wp-content/uploads/2017/07/cvpr2010.pdf>`_ .

    For each input :math:`X`, the equation is:

    .. math::
        Out = W \ast X + b

    Where:

    * :math:`X`: Input value, a ``Tensor`` with NCHW format.
    * :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] .
    * :math:`\ast`: Convolution operation.
    * :math:`b`: Bias value.
    * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.

    Example:
        - Input:
          Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
          Filter shape: :math:`(C_{in}, C_{out}, H_f, W_f)`
        - Output:
          Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`

        Where

        .. math::
            H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\
            W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1 \\
            H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] ) \\
            W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] )

    Parameters:
        in_channels(int): The largest number of channels in the input image.
        out_channels(int): The largest number of filters. It is as same as the output
            feature map.
        kernel_size(int or tuple): The largest filter size. If kernel_size is a tuple,
            it must contain two integers, (kernel_size_H, kernel_size_W).
            Otherwise, the filter will be a square.
        candidate_config(dict, optional): Dictionary that describes the candidate config of this layer,
            such as {'kernel_size': (3, 5, 7), 'channel': (4, 6, 8)}, means the kernel size of
            this layer can be chosen from (3, 5, 7). The key of candidate_config
            can only be 'kernel_size', 'channel' and 'expand_ratio'; 'channel' and 'expand_ratio'
            CANNOT be set at the same time. The dict passed in is not modified. Default: {}.
        transform_kernel(bool, optional): Whether to use transform matrix to transform a large filter
            to a small filter. Default: False.
        stride(int or tuple, optional): The stride size. If stride is a tuple, it must
            contain two integers, (stride_H, stride_W). Otherwise, the
            stride_H = stride_W = stride. Default: 1.
        padding(int or tuple, optional): The padding size. If padding is a tuple, it must
            contain two integers, (padding_H, padding_W). Otherwise, the
            padding_H = padding_W = padding. Default: 0.
        output_padding(int or tuple, optional): Forwarded to ``nn.Conv2DTranspose``; it is
            used in ``forward`` only when no ``output_size`` is given. Default: 0.
        dilation(int or tuple, optional): The dilation size. If dilation is a tuple, it must
            contain two integers, (dilation_H, dilation_W). Otherwise, the
            dilation_H = dilation_W = dilation. Default: 1.
        groups(int, optional): The groups number of the Conv2d transpose layer. Default: 1.
        weight_attr(ParamAttr, optional): The parameter attribute for the learnable weights,
            forwarded to ``nn.Conv2DTranspose``. Default: None.
        bias_attr(ParamAttr or bool, optional): The attribute for the bias,
            forwarded to ``nn.Conv2DTranspose``. Default: None.
        data_format(str, optional): Forwarded to ``nn.Conv2DTranspose``. Default: "NCHW".

    Attribute:
        **weight** (Parameter): the learnable weights of filters of this layer.
        **bias** (Parameter or None): the learnable bias of this layer.

    Returns:
        None

    Examples:
        .. code-block:: python

          import paddle
          import numpy as np
          from paddleslim.nas.ofa.layers_new import SuperConv2DTranspose
          data = np.random.random((3, 32, 32, 5)).astype('float32')
          super_convtranspose = SuperConv2DTranspose(32, 10, 3)
          ret = super_convtranspose(paddle.to_tensor(data), channel=5)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 candidate_config={},
                 transform_kernel=False,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 dilation=1,
                 groups=1,
                 weight_attr=None,
                 bias_attr=None,
                 data_format="NCHW"):
        super(SuperConv2DTranspose, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            output_padding=output_padding,
            groups=groups,
            weight_attr=weight_attr,
            bias_attr=bias_attr,
            data_format=data_format)

        # De-duplicate into a fresh dict: the previous in-place rewrite
        # changed the caller's dictionary.
        candidate_config = {
            key: list(set(values))
            for key, values in candidate_config.items()
        }
        self.candidate_config = candidate_config
        self.ks_set = candidate_config.get('kernel_size')
        self.expand_ratio = candidate_config.get('expand_ratio')
        self.channel = candidate_config.get('channel')

        # base_channel is the channel count that corresponds to expand_ratio == 1.
        self.base_channel = self._out_channels
        if self.expand_ratio:
            self.base_channel = int(self._out_channels /
                                    max(self.expand_ratio))

        self.transform_kernel = transform_kernel
        if self.ks_set is not None:
            # get_active_filter walks the sizes from largest to smallest.
            self.ks_set.sort()
        if self.transform_kernel != False:
            ### create parameter to transform kernel
            for ks_small, ks_large in zip(self.ks_set, self.ks_set[1:]):
                param_name = '%dto%d_matrix' % (ks_large, ks_small)
                ks_t = ks_small**2
                param = self.create_parameter(
                    attr=paddle.ParamAttr(
                        name=self._full_name + param_name,
                        initializer=nn.initializer.Assign(np.eye(ks_t))),
                    shape=(ks_t, ks_t),
                    dtype=self._dtype)
                setattr(self, param_name, param)

    def get_active_filter(self, in_nc, out_nc, kernel_size):
        """Return the slice of ``self.weight`` for the requested sub-network.

        The weight is sliced as ``[:in_nc, :out_nc]``. Without kernel
        transformation the centre window of the largest filter is returned;
        otherwise the filter is shrunk step by step with the learned
        transform matrices.
        """
        largest_ks = self._kernel_size[0]
        start, end = compute_start_end(largest_ks, kernel_size)
        filters = self.weight[:in_nc, :out_nc, start:end, start:end]
        if self.transform_kernel != False and kernel_size < largest_ks:
            start_filter = self.weight[:in_nc, :out_nc, :, :]
            for i in range(len(self.ks_set) - 1, 0, -1):
                src_ks = self.ks_set[i]
                if src_ks <= kernel_size:
                    break
                target_ks = self.ks_set[i - 1]
                start, end = compute_start_end(src_ks, target_ks)
                _input_filter = start_filter[:, :, start:end, start:end]
                # one row per (in, out) kernel so a single matmul transforms all
                _input_filter = paddle.reshape(
                    _input_filter,
                    shape=[(_input_filter.shape[0] * _input_filter.shape[1]),
                           -1])
                _input_filter = paddle.matmul(
                    _input_filter,
                    getattr(self, '%dto%d_matrix' % (src_ks, target_ks)),
                    False, False)
                _input_filter = paddle.reshape(
                    _input_filter,
                    shape=[
                        filters.shape[0], filters.shape[1], target_ks,
                        target_ks
                    ])
                start_filter = _input_filter
            filters = start_filter
        return filters

    def get_groups_in_out_nc(self, in_nc, out_nc):
        """Return ``(groups, weight_in_nc, weight_out_nc)`` for a standard conv."""
        ### standard conv
        return self._groups, in_nc, out_nc

    def forward(self,
                input,
                output_size=None,
                kernel_size=None,
                expand_ratio=None,
                channel=None):
        """Run the transposed convolution for the requested sub-network.

        Args:
            input: Input tensor; its dimension 1 is read as the number of
                active input channels.
            output_size(optional): Forwarded to ``F.conv2d_transpose``; when
                given, ``output_padding`` is set to 0.
            kernel_size(int, optional): Active kernel size. Defaults to the
                largest kernel size.
            expand_ratio(float, optional): Output channels are
                ``expand_ratio * self.base_channel``.
            channel(int, optional): Explicit number of output channels.
                ``expand_ratio`` and ``channel`` are mutually exclusive.
        """
        self.cur_config = {
            'kernel_size': kernel_size,
            'expand_ratio': expand_ratio,
            'channel': channel
        }
        in_nc = int(input.shape[1])
        assert (
            expand_ratio is None or channel is None
        ), "expand_ratio and channel CANNOT be NOT None at the same time."
        if expand_ratio is not None:
            out_nc = int(expand_ratio * self.base_channel)
        elif channel is not None:
            out_nc = int(channel)
        else:
            out_nc = self._out_channels

        ks = int(self._kernel_size[0]) if kernel_size is None else int(
            kernel_size)

        groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(
            in_nc, out_nc)
        weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)

        if kernel_size is not None or 'kernel_size' in self.candidate_config:
            # the kernel size is searchable: always use 'same' padding
            padding = convert_to_list(get_same_padding(ks), 2)
        else:
            padding = self._padding

        if output_size is None:
            output_padding = self.output_padding
        else:
            output_padding = 0

        if self.bias is not None:
            # Weight layout is (Cin, Cout/G, Kh, Kw): the layer produces
            # weight.shape[1] * groups channels, so slice the bias to match.
            bias = self.bias[:int(weight.shape[1]) * groups]
        else:
            bias = None

        # Use the group count computed above; self._groups is only valid for
        # the largest network and breaks depthwise convs on a pruned input.
        return F.conv2d_transpose(
            input,
            weight,
            bias=bias,
            padding=padding,
            output_padding=output_padding,
            stride=self._stride,
            dilation=self._dilation,
            groups=groups,
            output_size=output_size,
            data_format=self._data_format)
class SuperGroupConv2DTranspose(SuperConv2DTranspose):
    """``SuperConv2DTranspose`` variant for grouped transposed convolution."""

    def get_groups_in_out_nc(self, in_nc, out_nc):
        """Return ``(groups, in_nc, out_nc // groups)``.

        The weight of a grouped transposed conv is laid out as
        (Cin, Cout/G, Kh, Kw), so the output-channel slice is the per-group
        channel count.
        """
        ### groups convolution
        group_count = self._groups
        per_group_out_nc = int(out_nc // group_count)
        return group_count, in_nc, per_group_out_nc
class SuperDepthwiseConv2DTranspose(SuperConv2DTranspose):
    """``SuperConv2DTranspose`` variant for depthwise transposed convolution."""

    def get_groups_in_out_nc(self, in_nc, out_nc):
        """Return ``(in_nc, in_nc, in_nc)``.

        One group is used per input channel, so the requested ``out_nc`` is
        overridden by ``in_nc``; a debug message is logged when the two differ.
        """
        if in_nc != out_nc:
            _logger.debug(
                ("input channel and output channel in depthwise conv transpose is different, "
                 "change output channel to input channel! "
                 "origin channel:(in_nc {}, out_nc {}): ").format(in_nc,
                                                                  out_nc))
        # groups == input channels == output channels
        return in_nc, in_nc, in_nc
### NOTE: only search channel, write for GAN-compression, maybe change to SuperDepthwiseConv and SuperConv after.
class SuperSeparableConv2D(nn.Layer):
    """
    This interface is used to construct a callable object of the ``SuperSeparableConv2D``
    class.

    The difference between ``SuperSeparableConv2D`` and ``SeparableConv2D`` is:
    ``SuperSeparableConv2D`` accepts ``expand_ratio`` or ``channel`` when it is called.
    They select how many leading output channels of the second conv are used, so only
    that slice of the weight and bias is trained.

    The architecture of super separable convolution2D op is [Conv2D, norm layer(may be BatchNorm2D
    or InstanceNorm2D), Conv2D]. The first conv is depthwise conv, the filter number is input channel
    multiply scale_factor, the group is equal to the number of input channel. The second conv
    is standard conv, which filter size and stride size are 1.

    Parameters:
        in_channels(int): The number of channels in the input image.
        out_channels(int): The number of the second conv's filter. It is as same as the output
            feature map.
        kernel_size(int or tuple): The first conv's filter size. If kernel_size is a tuple,
            it must contain two integers, (kernel_size_H, kernel_size_W).
            Otherwise, the filter will be a square.
        candidate_config(dict, optional): Candidate config of this layer; only the
            'expand_ratio' key is read here. Default: {}.
        stride(int or tuple, optional): The first conv's stride size. If stride is a tuple,
            it must contain two integers, (stride_H, stride_W). Otherwise, the
            stride_H = stride_W = stride. Default: 1.
        padding(int or tuple, optional): The first conv's padding size. If padding is a tuple,
            it must contain two integers, (padding_H, padding_W). Otherwise, the
            padding_H = padding_W = padding. Default: 0.
        dilation(int or tuple, optional): The first conv's dilation size. If dilation is a tuple,
            it must contain two integers, (dilation_H, dilation_W). Otherwise, the
            dilation_H = dilation_W = dilation. Default: 1.
        norm_layer(class): The normalization layer between two convolution. Default: InstanceNorm2D.
        bias_attr (ParamAttr or bool, optional): The attribute for the bias of convolution.
            If it is set to False, no bias will be added to the output units.
            If it is set to None or one attribute of ParamAttr, convolution
            will create ParamAttr as bias_attr. If the Initializer of the bias_attr
            is not set, the bias is initialized zero. Default: None.
        scale_factor(float): The scale factor of the first conv's output channel. Default: 1.

    Returns:
        None
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 candidate_config={},
                 stride=1,
                 padding=0,
                 dilation=1,
                 norm_layer=nn.InstanceNorm2D,
                 bias_attr=None,
                 scale_factor=1):
        super(SuperSeparableConv2D, self).__init__()
        hidden_channels = in_channels * scale_factor
        # Sub-layers, in order: depthwise conv, norm layer, 1x1 conv.
        self.conv = nn.LayerList([
            nn.Conv2D(
                in_channels=in_channels,
                out_channels=hidden_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                # dilation was accepted and documented but never applied
                dilation=dilation,
                groups=in_channels,
                bias_attr=bias_attr),
            norm_layer(hidden_channels),
            nn.Conv2D(
                in_channels=hidden_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                bias_attr=bias_attr),
        ])

        self.candidate_config = candidate_config
        self.expand_ratio = candidate_config.get('expand_ratio')
        # base_output_dim is the channel count that corresponds to expand_ratio == 1.
        self.base_output_dim = self.conv[0]._out_channels
        if self.expand_ratio is not None:
            self.base_output_dim = int(self.conv[0]._out_channels /
                                       max(self.expand_ratio))

    def forward(self, input, expand_ratio=None, channel=None):
        """Apply depthwise conv, norm layer and 1x1 conv on the active channels.

        Args:
            input: Input tensor; its dimension 1 is read as the number of
                active input channels.
            expand_ratio(float, optional): Output channels are
                ``expand_ratio * self.base_output_dim``.
            channel(int, optional): Explicit number of output channels.
                ``expand_ratio`` and ``channel`` are mutually exclusive.
        """
        self.cur_config = {'expand_ratio': expand_ratio, 'channel': channel}
        in_nc = int(input.shape[1])
        assert (
            expand_ratio is None or channel is None
        ), "expand_ratio and channel CANNOT be NOT None at the same time."
        if expand_ratio is not None:
            out_nc = int(expand_ratio * self.base_output_dim)
        elif channel is not None:
            out_nc = int(channel)
        else:
            # NOTE(review): the default comes from the first conv's channel
            # count although it slices the second conv's outputs -- confirm
            # this is intended when out_channels differs from it.
            out_nc = self.conv[0]._out_channels

        depthwise, norm, pointwise = self.conv[0], self.conv[1], self.conv[2]

        ### conv1: depthwise over the active input channels
        dw_weight = depthwise.weight[:in_nc]
        dw_bias = depthwise.bias[:in_nc] if depthwise.bias is not None else None
        conv0_out = F.conv2d(
            input,
            dw_weight,
            dw_bias,
            stride=depthwise._stride,
            padding=depthwise._padding,
            dilation=depthwise._dilation,
            groups=in_nc,
            data_format=depthwise._data_format)

        norm_out = norm(conv0_out)

        ### conv2: 1x1 conv restricted to the active input / output channels
        pw_weight = pointwise.weight[:out_nc, :in_nc, :, :]
        pw_bias = pointwise.bias[:out_nc] if pointwise.bias is not None else None
        return F.conv2d(
            norm_out,
            pw_weight,
            pw_bias,
            stride=pointwise._stride,
            padding=pointwise._padding,
            dilation=pointwise._dilation,
            groups=pointwise._groups,
            data_format=pointwise._data_format)
class SuperLinear(nn.Linear):
    """
    Linear layer whose number of active output features can be chosen per call.

    The layer is created with the largest ``in_features`` / ``out_features``;
    ``forward`` slices the weight, laid out as (Cin, Cout), and the bias down
    to the active sizes.
    """

    def __init__(self,
                 in_features,
                 out_features,
                 candidate_config={},
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        super(SuperLinear, self).__init__(in_features, out_features,
                                          weight_attr, bias_attr, name)
        self._weight_attr = weight_attr
        self._bias_attr = bias_attr
        self._in_features = in_features
        self._out_features = out_features
        self.candidate_config = candidate_config
        if 'expand_ratio' in candidate_config:
            self.expand_ratio = candidate_config['expand_ratio']
        else:
            self.expand_ratio = None
        # base_output_dim is the feature count that corresponds to expand_ratio == 1
        self.base_output_dim = self._out_features
        if self.expand_ratio is not None:
            largest_ratio = max(self.expand_ratio)
            self.base_output_dim = int(self._out_features / largest_ratio)

    def forward(self, input, expand_ratio=None, channel=None):
        """Apply the linear transform restricted to the active features.

        ``expand_ratio`` and ``channel`` are mutually exclusive; when neither
        is given all output features are used. The number of input features
        is read from the last dimension of ``input``.
        """
        self.cur_config = {'expand_ratio': expand_ratio, 'channel': channel}
        ### weight: (Cin, Cout)
        active_in = int(input.shape[-1])
        assert (
            expand_ratio is None or channel is None
        ), "expand_ratio and channel CANNOT be NOT None at the same time."
        if expand_ratio is not None:
            active_out = int(expand_ratio * self.base_output_dim)
        elif channel is not None:
            active_out = int(channel)
        else:
            active_out = self._out_features

        sub_weight = self.weight[:active_in, :active_out]
        # a bias only exists when bias_attr was not False
        sub_bias = self.bias[:active_out] if self._bias_attr != False else self.bias
        return F.linear(
            x=input, weight=sub_weight, bias=sub_bias, name=self.name)
class SuperBatchNorm2D(nn.BatchNorm2D):
    """
    BatchNorm2D that normalizes only the channels present in its input.

    ``forward`` reads the channel count from dimension 1 of the input and
    slices weight, bias, running mean and running variance to that size.
    """

    def __init__(self,
                 num_features,
                 momentum=0.9,
                 epsilon=1e-05,
                 weight_attr=None,
                 bias_attr=None,
                 data_format='NCHW',
                 name=None):
        super(SuperBatchNorm2D, self).__init__(
            num_features, momentum, epsilon, weight_attr, bias_attr,
            data_format, name)

    def forward(self, input):
        """Batch-normalize ``input`` with the leading slice of the statistics."""
        self._check_data_format(self._data_format)
        self._check_input_dim(input)

        active = int(input.shape[1])
        running_mean = self._mean[:active]
        running_var = self._variance[:active]
        return F.batch_norm(
            input,
            running_mean,
            running_var,
            weight=self.weight[:active],
            bias=self.bias[:active],
            training=self.training,
            momentum=self._momentum,
            epsilon=self._epsilon,
            data_format=self._data_format)
class SuperInstanceNorm2D(nn.InstanceNorm2D):
    """
    InstanceNorm2D that normalizes only the channels present in its input.

    ``forward`` reads the channel count from dimension 1 of the input and
    slices the scale and bias parameters (when they exist) to that size.
    """

    def __init__(self,
                 num_features,
                 epsilon=1e-05,
                 momentum=0.9,
                 weight_attr=None,
                 bias_attr=None,
                 data_format='NCHW',
                 name=None):
        super(SuperInstanceNorm2D, self).__init__(num_features, epsilon,
                                                  momentum, weight_attr,
                                                  bias_attr, data_format, name)

    def forward(self, input):
        """Instance-normalize ``input`` with the leading slice of scale/bias."""
        self._check_input_dim(input)

        feature_dim = int(input.shape[1])
        # Slice whatever parameters exist. The previous test required BOTH
        # attrs to be False before skipping the slice, so disabling only one
        # of them ended up indexing a parameter that was never created.
        scale = self.scale[:feature_dim] if self.scale is not None else None
        bias = self.bias[:feature_dim] if self.bias is not None else None

        # Pass by keyword: the positional slots after the input of
        # F.instance_norm are running_mean / running_var, not weight / bias.
        return F.instance_norm(
            input, weight=scale, bias=bias, eps=self._epsilon)
class SuperLayerNorm(nn.LayerNorm):
    """
    LayerNorm that uses only the leading slice of its weight and bias.

    The slice length is the size of the last dimension of the input.
    """

    def __init__(self,
                 normalized_shape,
                 epsilon=1e-05,
                 weight_attr=None,
                 bias_attr=None,
                 name=None):
        super(SuperLayerNorm, self).__init__(
            normalized_shape, epsilon, weight_attr, bias_attr, name)

    def forward(self, input):
        """Layer-normalize ``input`` over its trailing normalized dimensions."""
        ### TODO(ceci3): fix if normalized_shape is not a single number
        # normalization starts at the first of the trailing normalized axes
        begin_norm_axis = len(list(input.shape)) - len(self._normalized_shape)
        active = int(input.shape[-1])

        sub_weight = self.weight[:active] if self._weight_attr != False else None
        sub_bias = self.bias[:active] if self._bias_attr != False else None

        normalized, _, _ = core.ops.layer_norm(
            input, sub_weight, sub_bias, 'epsilon', self._epsilon,
            'begin_norm_axis', begin_norm_axis)
        return normalized
class SuperEmbedding(nn.Embedding):
    """
    Embedding layer whose active embedding width can be chosen per call.

    The layer is created with the largest ``embedding_dim``; ``forward``
    keeps only the leading columns of the embedding table.

    Parameters:
        num_embeddings(int): Number of rows of the embedding table.
        embedding_dim(int): Largest embedding width.
        candidate_config(dict, optional): Candidate config of this layer; only
            the 'expand_ratio' key is read here. Default: {}.
        padding_idx(int, optional): Forwarded to ``nn.Embedding``. Default: None.
        sparse(bool, optional): Forwarded to ``nn.Embedding``. Default: False.
        weight_attr(ParamAttr, optional): Forwarded to ``nn.Embedding``. Default: None.
        name(str, optional): Forwarded to ``nn.Embedding``. Default: None.
    """

    def __init__(self,
                 num_embeddings,
                 embedding_dim,
                 candidate_config={},
                 padding_idx=None,
                 sparse=False,
                 weight_attr=None,
                 name=None):
        # Forward by keyword: the previous positional call skipped
        # padding_idx, so `sparse` landed in the padding_idx slot and every
        # following argument was shifted by one.
        super(SuperEmbedding, self).__init__(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
            sparse=sparse,
            weight_attr=weight_attr,
            name=name)
        self.candidate_config = candidate_config
        self.expand_ratio = candidate_config.get('expand_ratio')
        # base_output_dim is the width that corresponds to expand_ratio == 1.
        self.base_output_dim = self._embedding_dim
        if self.expand_ratio is not None:
            self.base_output_dim = int(self._embedding_dim /
                                       max(self.expand_ratio))

    def forward(self, input, expand_ratio=None, channel=None):
        """Look up ``input`` in the leading ``out_nc`` columns of the table.

        ``expand_ratio`` and ``channel`` are mutually exclusive; when neither
        is given the full embedding width is used.
        """
        assert (
            expand_ratio is None or channel is None
        ), "expand_ratio and channel CANNOT be NOT None at the same time."
        if expand_ratio is not None:
            out_nc = int(expand_ratio * self.base_output_dim)
        elif channel is not None:
            out_nc = int(channel)
        else:
            out_nc = self._embedding_dim

        weight = self.weight[:, :out_nc]
        return F.embedding(
            input,
            weight=weight,
            padding_idx=self._padding_idx,
            sparse=self._sparse,
            name=self._name)
......@@ -16,10 +16,15 @@ import logging
import numpy as np
from collections import namedtuple
import paddle
#import paddle.nn as nn
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D
from .layers import BaseBlock, Block, SuperConv2D, SuperBatchNorm
from .utils.utils import get_paddle_version
pd_ver = get_paddle_version()
if pd_ver == 185:
from .layers import BaseBlock, SuperConv2D
Layer = paddle.fluid.dygraph.Layer
else:
from .layers_new import BaseBlock, SuperConv2D
Layer = paddle.nn.Layer
from .utils.utils import search_idx
from ...common import get_logger
......@@ -40,7 +45,7 @@ DistillConfig = namedtuple('DistillConfig', [
DistillConfig.__new__.__defaults__ = (None, ) * len(DistillConfig._fields)
class OFABase(fluid.dygraph.Layer):
class OFABase(Layer):
def __init__(self, model):
super(OFABase, self).__init__()
self.model = model
......@@ -169,8 +174,7 @@ class OFA(OFABase):
)
### instance model by user can input super-param easily.
assert isinstance(self.distill_config.teacher_model,
paddle.fluid.dygraph.Layer)
assert isinstance(self.distill_config.teacher_model, Layer)
# load teacher parameter
if self.distill_config.teacher_model_path != None:
......@@ -190,9 +194,10 @@ class OFA(OFABase):
for name, sublayer in self.model.named_sublayers():
if name in mapping_layers:
netA = SuperConv2D(
sublayer._num_filters,
sublayer._num_filters,
filter_size=1)
getattr(sublayer, '_num_filters',
sublayer._out_channels),
getattr(sublayer, '_num_filters',
sublayer._out_channels), 1)
self.netAs_param.extend(netA.parameters())
self.netAs.append(netA)
......@@ -288,7 +293,8 @@ class OFA(OFABase):
n = self.distill_config.mapping_layers[i]
Tact = self.Tacts[n]
Sact = self.Sacts[n]
Sact = netA(Sact, channel=netA._num_filters)
Sact = netA(
Sact, channel=getattr(netA, '_num_filters', netA._out_channels))
if self.distill_config.distill_fn == None:
loss = fluid.layers.mse_loss(Sact, Tact)
else:
......
......@@ -44,3 +44,13 @@ def search_idx(num, sorted_nestlist):
return idx, phase_idx
assert num > max_num
return len(sorted_nestlist) - 1, max_idx
def get_paddle_version():
    """Return 200 when the installed paddle exposes ``paddle.nn.Conv1D``, else 185."""
    import paddle

    ### judge 2.0 alpha
    nn_module = getattr(paddle, 'nn', None)
    has_new_api = nn_module is not None and hasattr(nn_module, 'Conv1D')
    return 200 if has_new_api else 185
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
from paddle.vision.models import mobilenet_v1
from paddleslim.nas.ofa.convert_super import Convert, supernet
class TestConvertSuper(unittest.TestCase):
    """Converts a MobileNetV1 into a super-network and checks its sublayer count."""

    def setUp(self):
        self.model = mobilenet_v1()

    def test_convert(self):
        # search space: three kernel sizes and three expand ratios
        search_space = supernet(
            kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4])
        converter = Convert(search_space)
        sp_model = converter.convert(self.model)
        assert len(sp_model.sublayers()) == 151
if __name__ == '__main__':
unittest.main()
......@@ -17,16 +17,15 @@ sys.path.append("../")
import numpy as np
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph.nn as nn
import paddle.nn as nn
from paddle.nn import ReLU
from paddleslim.nas import ofa
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
from paddleslim.nas.ofa.convert_super import supernet
from paddleslim.nas.ofa.layers import Block, SuperSeparableConv2D
from paddleslim.nas.ofa.layers_new import Block, SuperSeparableConv2D
class ModelConv(fluid.dygraph.Layer):
class ModelConv(nn.Layer):
def __init__(self):
super(ModelConv, self).__init__()
with supernet(
......@@ -35,16 +34,13 @@ class ModelConv(fluid.dygraph.Layer):
(8, 12, 16))) as ofa_super:
models = []
models += [nn.Conv2D(3, 4, 3, padding=1)]
models += [nn.InstanceNorm(4)]
models += [nn.InstanceNorm2D(4)]
models += [ReLU()]
models += [nn.Conv2D(4, 4, 3, groups=4)]
models += [nn.InstanceNorm(4)]
models += [nn.InstanceNorm2D(4)]
models += [ReLU()]
models += [
nn.Conv2DTranspose(
4, 4, 3, groups=4, padding=1, use_cudnn=True)
]
models += [nn.BatchNorm(4)]
models += [nn.Conv2DTranspose(4, 4, 3, groups=4, padding=1)]
models += [nn.BatchNorm2D(4)]
models += [ReLU()]
models += [nn.Conv2D(4, 3, 3)]
models += [ReLU()]
......@@ -60,21 +56,23 @@ class ModelConv(fluid.dygraph.Layer):
kernel_size=(3, 5, 7), expand_ratio=(1, 2, 4)) as ofa_super:
models1 = []
models1 += [nn.Conv2D(6, 4, 3)]
models1 += [nn.BatchNorm(4)]
models1 += [nn.BatchNorm2D(4)]
models1 += [ReLU()]
models1 += [nn.Conv2D(4, 4, 3, groups=2)]
models1 += [nn.InstanceNorm(4)]
models1 += [nn.InstanceNorm2D(4)]
models1 += [ReLU()]
models1 += [nn.Conv2DTranspose(4, 4, 3, groups=2)]
models1 += [nn.BatchNorm(4)]
models1 += [nn.BatchNorm2D(4)]
models1 += [ReLU()]
models1 += [nn.Conv2DTranspose(4, 4, 3)]
models1 += [nn.BatchNorm(4)]
models1 += [nn.BatchNorm2D(4)]
models1 += [ReLU()]
models1 += [nn.Conv2DTranspose(4, 4, 1)]
models1 += [nn.BatchNorm2D(4)]
models1 += [ReLU()]
models1 = ofa_super.convert(models1)
models += models1
self.models = paddle.nn.Sequential(*models)
def forward(self, inputs, depth=None):
......@@ -89,16 +87,61 @@ class ModelConv(fluid.dygraph.Layer):
return inputs
class ModelLinear(fluid.dygraph.Layer):
class ModelConv2(nn.Layer):
    """Test model built from three converted stages.

    The stages are converted with an expand_ratio, a channel and a
    kernel_size search space respectively, then chained in one Sequential.
    """

    def __init__(self):
        super(ModelConv2, self).__init__()
        # stage 0: searchable expand ratio
        with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
            stage0 = [
                nn.Conv2DTranspose(4, 4, 3),
                nn.BatchNorm2D(4),
                ReLU(),
                nn.Conv2D(4, 4, 3),
                nn.BatchNorm2D(4),
                ReLU(),
            ]
            models = ofa_super.convert(stage0)

        # stage 1: searchable channel numbers
        with supernet(channel=((4, 6, 8), (4, 6, 8))) as ofa_super:
            stage1 = [
                nn.Conv2DTranspose(4, 4, 3),
                nn.BatchNorm2D(4),
                ReLU(),
                nn.Conv2DTranspose(4, 4, 3),
                nn.BatchNorm2D(4),
                ReLU(),
            ]
            models += ofa_super.convert(stage1)

        # stage 2: searchable kernel size
        with supernet(kernel_size=(3, 5, 7)) as ofa_super:
            stage2 = [
                nn.Conv2D(4, 4, 3),
                nn.BatchNorm2D(4),
                ReLU(),
                nn.Conv2DTranspose(4, 4, 3),
                nn.BatchNorm2D(4),
                ReLU(),
                nn.Conv2D(4, 4, 3),
                nn.BatchNorm2D(4),
                ReLU(),
            ]
            models += ofa_super.convert(stage2)

        self.models = paddle.nn.Sequential(*models)
class ModelLinear(nn.Layer):
def __init__(self):
super(ModelLinear, self).__init__()
with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
models = []
models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
models += [nn.Linear(64, 128)]
models += [nn.LayerNorm(128)]
models += [nn.Linear(128, 256)]
models = ofa_super.convert(models)
with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
models1 = []
models1 += [nn.Embedding(size=(64, 64))]
models1 += [nn.Linear(64, 128)]
models1 += [nn.LayerNorm(128)]
models1 += [nn.Linear(128, 256)]
models1 += [nn.Linear(256, 256)]
models1 = ofa_super.convert(models1)
models += models1
......@@ -116,17 +159,21 @@ class ModelLinear(fluid.dygraph.Layer):
return inputs
class ModelLinear1(fluid.dygraph.Layer):
class ModelLinear1(nn.Layer):
def __init__(self):
super(ModelLinear1, self).__init__()
models = []
with supernet(channel=((64, 128, 256), (64, 128, 256),
(64, 128, 256))) as ofa_super:
models = []
models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
models += [nn.Linear(64, 128)]
models += [nn.LayerNorm(128)]
models += [nn.Linear(128, 256)]
models = ofa_super.convert(models)
with supernet(channel=((64, 128, 256), )) as ofa_super:
models1 = []
models1 += [nn.Embedding(size=(64, 64))]
models1 += [nn.Linear(64, 128)]
models1 += [nn.LayerNorm(128)]
models1 += [nn.Linear(128, 256)]
models1 += [nn.Linear(256, 256)]
models1 = ofa_super.convert(models1)
models += models1
......@@ -145,20 +192,16 @@ class ModelLinear1(fluid.dygraph.Layer):
return inputs
class ModelLinear2(fluid.dygraph.Layer):
class ModelLinear2(nn.Layer):
def __init__(self):
super(ModelLinear2, self).__init__()
models = []
with supernet(expand_ratio=None) as ofa_super:
models1 = []
models1 += [nn.Embedding(size=(64, 64))]
models1 += [nn.Linear(64, 128)]
models1 += [nn.LayerNorm(128)]
models1 += [nn.Linear(128, 256)]
models1 = ofa_super.convert(models1)
models += models1
models = []
models += [nn.Embedding(num_embeddings=64, embedding_dim=64)]
models += [nn.Linear(64, 128)]
models += [nn.LayerNorm(128)]
models += [nn.Linear(128, 256)]
models = ofa_super.convert(models)
self.models = paddle.nn.Sequential(*models)
def forward(self, inputs, depth=None):
......@@ -175,7 +218,6 @@ class ModelLinear2(fluid.dygraph.Layer):
class TestOFA(unittest.TestCase):
def setUp(self):
fluid.enable_dygraph()
self.init_model_and_data()
self.init_config()
......@@ -185,7 +227,7 @@ class TestOFA(unittest.TestCase):
data_np = np.random.random((1, 3, 10, 10)).astype(np.float32)
label_np = np.random.random((1)).astype(np.float32)
self.data = fluid.dygraph.to_variable(data_np)
self.data = paddle.to_tensor(data_np)
def init_config(self):
default_run_config = {
......@@ -217,10 +259,9 @@ class TestOFA(unittest.TestCase):
cur_idx = self.run_config.n_epochs[idx]
for ph_idx in range(len(cur_idx)):
cur_lr = self.run_config.init_learning_rate[idx][ph_idx]
adam = fluid.optimizer.Adam(
adam = paddle.optimizer.Adam(
learning_rate=cur_lr,
parameter_list=(
ofa_model.parameters() + ofa_model.netAs_param))
parameters=(ofa_model.parameters() + ofa_model.netAs_param))
for epoch_id in range(start_epoch,
self.run_config.n_epochs[idx][ph_idx]):
if epoch_id == 0:
......@@ -228,7 +269,7 @@ class TestOFA(unittest.TestCase):
for model_no in range(self.run_config.dynamic_batch_size[
idx]):
output, _ = ofa_model(self.data)
loss = fluid.layers.reduce_mean(output)
loss = paddle.mean(output)
if self.distill_config.mapping_layers != None:
dis_loss = ofa_model.calc_distill_loss()
loss += dis_loss
......@@ -249,7 +290,7 @@ class TestOFACase1(TestOFA):
self.teacher_model = ModelLinear()
data_np = np.random.random((3, 64)).astype(np.int64)
self.data = fluid.dygraph.to_variable(data_np)
self.data = paddle.to_tensor(data_np)
def init_config(self):
default_run_config = {
......@@ -275,7 +316,7 @@ class TestOFACase2(TestOFACase1):
self.teacher_model = ModelLinear1()
data_np = np.random.random((3, 64)).astype(np.int64)
self.data = fluid.dygraph.to_variable(data_np)
self.data = paddle.to_tensor(data_np)
class TestOFACase3(unittest.TestCase):
......@@ -285,5 +326,10 @@ class TestOFACase3(unittest.TestCase):
ofa_model.set_net_config({'expand_ratio': None})
class TestOFACase4(unittest.TestCase):
    """Checks that ``ModelConv2`` can be constructed.

    Renamed from ``TestOFACase3``: a second class with that name rebinds the
    module-level name, so the earlier ``TestOFACase3`` was never collected.
    """

    def test_ofa(self):
        self.model = ModelConv2()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册