Unverified commit b09698bb authored by ceci3, committed by GitHub

Add ofa (#416)

* add ofa
Parent 8854b71d
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph.nn as nn
from paddle.nn import ReLU
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
from paddleslim.nas.ofa import supernet
class Model(fluid.dygraph.Layer):
def __init__(self):
super(Model, self).__init__()
with supernet(
kernel_size=(3, 5, 7), expand_ratio=[1, 2, 4]) as ofa_super:
models = []
models += [nn.Conv2D(1, 6, 3)]
models += [ReLU()]
models += [nn.Pool2D(2, 'max', 2)]
models += [nn.Conv2D(6, 16, 5, padding=0)]
models += [ReLU()]
models += [nn.Pool2D(2, 'max', 2)]
models += [
nn.Linear(784, 120), nn.Linear(120, 84), nn.Linear(84, 10)
]
models = ofa_super.convert(models)
self.models = paddle.nn.Sequential(*models)
def forward(self, inputs, label, depth=None):
if depth != None:
assert isinstance(depth, int)
assert depth < len(self.models)
models = self.models[:depth]
else:
depth = len(self.models)
models = self.models[:]
for idx, layer in enumerate(models):
if idx == 6:
inputs = fluid.layers.flatten(inputs, 1)
inputs = layer(inputs)
inputs = fluid.layers.softmax(inputs)
return inputs
def test_ofa():
default_run_config = {
'train_batch_size': 256,
'eval_batch_size': 64,
'n_epochs': [[1], [2, 3], [4, 5]],
'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
'dynamic_batch_size': [1, 1, 1],
'total_images': 50000, #1281167,
'elastic_depth': (2, 5, 8)
}
run_config = RunConfig(**default_run_config)
default_distill_config = {
'lambda_distill': 0.01,
'teacher_model': Model,
'mapping_layers': ['models.0.fn']
}
distill_config = DistillConfig(**default_distill_config)
fluid.enable_dygraph()
model = Model()
ofa_model = OFA(model, run_config, distill_config=distill_config)
train_reader = paddle.fluid.io.batch(
paddle.dataset.mnist.train(), batch_size=256, drop_last=True)
start_epoch = 0
for idx in range(len(run_config.n_epochs)):
cur_idx = run_config.n_epochs[idx]
for ph_idx in range(len(cur_idx)):
cur_lr = run_config.init_learning_rate[idx][ph_idx]
adam = fluid.optimizer.Adam(
learning_rate=cur_lr,
parameter_list=(ofa_model.parameters() + ofa_model.netAs_param))
for epoch_id in range(start_epoch,
run_config.n_epochs[idx][ph_idx]):
for batch_id, data in enumerate(train_reader()):
dy_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = fluid.dygraph.to_variable(dy_x_data)
label = fluid.dygraph.to_variable(y_data)
label.stop_gradient = True
for model_no in range(run_config.dynamic_batch_size[idx]):
output, _ = ofa_model(img, label)
loss = fluid.layers.reduce_mean(output)
dis_loss = ofa_model.calc_distill_loss()
loss += dis_loss
loss.backward()
if batch_id % 10 == 0:
print(
'epoch: {}, batch: {}, loss: {}, distill loss: {}'.
format(epoch_id, batch_id,
loss.numpy()[0], dis_loss.numpy()[0]))
### accumulate gradients of dynamic_batch_size sub-networks for the same batch of data
### NOTE: gradient accumulation in PaddlePaddle still needs to be fixed
adam.minimize(loss)
adam.clear_gradients()
start_epoch = run_config.n_epochs[idx][ph_idx]
test_ofa()
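### Sketch (assumption, not part of the original test): once an OFA model has been
### built as above, the elastic task can also be fixed manually with the
### `set_task`/`set_epoch` API defined in ofa.py, instead of relying on progressive
### shrinking. `ofa_model`, `img` and `label` are assumed to be constructed the same
### way as inside test_ofa().
def train_one_step_fixed_task(ofa_model, img, label):
    ofa_model.set_epoch(0)
    ofa_model.set_task(task=['kernel_size'])  # only sample the kernel size
    output, _ = ofa_model(img, label)
    loss = fluid.layers.reduce_mean(output)
    loss.backward()
    return loss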
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .ofa import OFA, RunConfig, DistillConfig
from .convert_super import supernet
from .layers import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import decorator
import logging
import paddle
import paddle.fluid as fluid
from paddle.fluid import framework
from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm
from .layers import *
from ...common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
__all__ = ['supernet']
WEIGHT_LAYER = ['conv', 'linear']
### TODO: add decorator
class Convert:
def __init__(self, context):
self.context = context
def convert(self, model):
# search for the first and last weight layers; don't change the output channel of the last
# weight layer or the input channel of the first weight layer
first_weight_layer_idx = -1
last_weight_layer_idx = -1
weight_layer_count = 0
# NOTE: pre_channel is stored for shortcut modules
pre_channel = 0
cur_channel = None
for idx, layer in enumerate(model):
cls_name = layer.__class__.__name__.lower()
if 'conv' in cls_name or 'linear' in cls_name:
weight_layer_count += 1
last_weight_layer_idx = idx
if first_weight_layer_idx == -1:
first_weight_layer_idx = idx
if getattr(self.context, 'channel', None) != None:
assert len(
self.context.channel
) == weight_layer_count, "length of channel must be the same as the number of weight layers."
for idx, layer in enumerate(model):
if isinstance(layer, Conv2D):
attr_dict = layer.__dict__
key = attr_dict['_full_name']
new_attr_name = [
'_stride', '_dilation', '_groups', '_param_attr',
'_bias_attr', '_use_cudnn', '_act', '_dtype'
]
new_attr_dict = dict()
new_attr_dict['candidate_config'] = dict()
self.kernel_size = getattr(self.context, 'kernel_size', None)
if self.kernel_size != None:
new_attr_dict['transform_kernel'] = True
# if the kernel_size of conv is 1, don't change it.
#if self.kernel_size and int(attr_dict['_filter_size'][0]) != 1:
if self.kernel_size and int(attr_dict['_filter_size']) != 1:
new_attr_dict['filter_size'] = max(self.kernel_size)
new_attr_dict['candidate_config'].update({
'kernel_size': self.kernel_size
})
else:
new_attr_dict['filter_size'] = attr_dict['_filter_size']
if self.context.expand:
### first super convolution
if idx == first_weight_layer_idx:
new_attr_dict['num_channels'] = attr_dict[
'_num_channels']
else:
new_attr_dict[
'num_channels'] = self.context.expand * attr_dict[
'_num_channels']
### last super convolution
if idx == last_weight_layer_idx:
new_attr_dict['num_filters'] = attr_dict['_num_filters']
else:
new_attr_dict[
'num_filters'] = self.context.expand * attr_dict[
'_num_filters']
new_attr_dict['candidate_config'].update({
'expand_ratio': self.context.expand_ratio
})
elif self.context.channel:
if attr_dict['_groups'] != None and (
int(attr_dict['_groups']) ==
int(attr_dict['_num_channels'])):
### depthwise conv: if the conv is depthwise, use pre_channel as cur_channel
_logger.warn(
"If the convolution is a depthwise conv, the output channel is changed" \
" to match the input channel; the searched output channel is not used."
)
cur_channel = pre_channel
else:
cur_channel = self.context.channel[0]
self.context.channel = self.context.channel[1:]
if idx == first_weight_layer_idx:
new_attr_dict['num_channels'] = attr_dict[
'_num_channels']
else:
new_attr_dict['num_channels'] = max(pre_channel)
if idx == last_weight_layer_idx:
new_attr_dict['num_filters'] = attr_dict['_num_filters']
else:
new_attr_dict['num_filters'] = max(cur_channel)
new_attr_dict['candidate_config'].update({
'channel': cur_channel
})
pre_channel = cur_channel
else:
new_attr_dict['num_filters'] = attr_dict['_num_filters']
new_attr_dict['num_channels'] = attr_dict['_num_channels']
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
del layer
if attr_dict['_groups'] == None or int(attr_dict[
'_groups']) == 1:
### standard conv
layer = Block(SuperConv2D(**new_attr_dict), key=key)
elif int(attr_dict['_groups']) == int(attr_dict[
'_num_channels']):
# if conv is depthwise conv, groups = in_channel, out_channel = in_channel,
# channel in candidate_config = in_channel_list
if 'channel' in new_attr_dict['candidate_config']:
new_attr_dict['num_channels'] = max(cur_channel)
new_attr_dict['num_filters'] = new_attr_dict[
'num_channels']
new_attr_dict['candidate_config'][
'channel'] = cur_channel
new_attr_dict['groups'] = new_attr_dict['num_channels']
layer = Block(
SuperDepthwiseConv2D(**new_attr_dict), key=key)
else:
### group conv
layer = Block(SuperGroupConv2D(**new_attr_dict), key=key)
model[idx] = layer
elif isinstance(layer, BatchNorm) and (
getattr(self.context, 'expand', None) != None or
getattr(self.context, 'channel', None) != None):
# num_features in BatchNorm layers after the last weight layer does not change
if idx > last_weight_layer_idx:
continue
attr_dict = layer.__dict__
new_attr_name = [
'_param_attr', '_bias_attr', '_act', '_dtype', '_in_place',
'_data_layout', '_momentum', '_epsilon', '_is_test',
'_use_global_stats', '_trainable_statistics'
]
new_attr_dict = dict()
if self.context.expand:
new_attr_dict['num_channels'] = self.context.expand * int(
layer._parameters['weight'].shape[0])
elif self.context.channel:
new_attr_dict['num_channels'] = max(cur_channel)
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
del layer, attr_dict
layer = SuperBatchNorm(**new_attr_dict)
model[idx] = layer
### assume output_size = None, filter_size != None
### NOTE: output_size != None may raise an error; handle it when it happens.
elif isinstance(layer, Conv2DTranspose):
attr_dict = layer.__dict__
key = attr_dict['_full_name']
new_attr_name = [
'_stride', '_dilation', '_groups', '_param_attr',
'_bias_attr', '_use_cudnn', '_act', '_dtype', '_output_size'
]
assert attr_dict[
'_filter_size'] != None, "Conv2DTranspose only supports filter_size != None for now"
new_attr_dict = dict()
new_attr_dict['candidate_config'] = dict()
self.kernel_size = getattr(self.context, 'kernel_size', None)
if self.kernel_size != None:
new_attr_dict['transform_kernel'] = True
# if the kernel_size of conv transpose is 1, don't change it.
if self.kernel_size and int(attr_dict['_filter_size'][0]) != 1:
new_attr_dict['filter_size'] = max(self.kernel_size)
new_attr_dict['candidate_config'].update({
'kernel_size': self.kernel_size
})
else:
new_attr_dict['filter_size'] = attr_dict['_filter_size']
if self.context.expand:
### first super convolution transpose
if idx == first_weight_layer_idx:
new_attr_dict['num_channels'] = attr_dict[
'_num_channels']
else:
new_attr_dict[
'num_channels'] = self.context.expand * attr_dict[
'_num_channels']
### last super convolution transpose
if idx == last_weight_layer_idx:
new_attr_dict['num_filters'] = attr_dict['_num_filters']
else:
new_attr_dict[
'num_filters'] = self.context.expand * attr_dict[
'_num_filters']
new_attr_dict['candidate_config'].update({
'expand_ratio': self.context.expand_ratio
})
elif self.context.channel:
if attr_dict['_groups'] != None and (
int(attr_dict['_groups']) ==
int(attr_dict['_num_channels'])):
### depthwise conv_transpose
_logger.warn(
"If convolution is a depthwise conv_transpose, output channel " \
"change to the same channel with input, output channel in search is not used."
)
cur_channel = pre_channel
else:
cur_channel = self.context.channel[0]
self.context.channel = self.context.channel[1:]
if idx == first_weight_layer_idx:
new_attr_dict['num_channels'] = attr_dict[
'_num_channels']
else:
new_attr_dict['num_channels'] = max(pre_channel)
if idx == last_weight_layer_idx:
new_attr_dict['num_filters'] = attr_dict['_num_filters']
else:
new_attr_dict['num_filters'] = max(cur_channel)
new_attr_dict['candidate_config'].update({
'channel': cur_channel
})
pre_channel = cur_channel
else:
new_attr_dict['num_filters'] = attr_dict['_num_filters']
new_attr_dict['num_channels'] = attr_dict['_num_channels']
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
del layer
if new_attr_dict['output_size'] == []:
new_attr_dict['output_size'] = None
if attr_dict['_groups'] == None or int(attr_dict[
'_groups']) == 1:
### standard conv_transpose
layer = Block(
SuperConv2DTranspose(**new_attr_dict), key=key)
elif int(attr_dict['_groups']) == int(attr_dict[
'_num_channels']):
# if conv is depthwise conv, groups = in_channel, out_channel = in_channel,
# channel in candidate_config = in_channel_list
if 'channel' in new_attr_dict['candidate_config']:
new_attr_dict['num_channels'] = max(cur_channel)
new_attr_dict['num_filters'] = new_attr_dict[
'num_channels']
new_attr_dict['candidate_config'][
'channel'] = cur_channel
new_attr_dict['groups'] = new_attr_dict['num_channels']
layer = Block(
SuperDepthwiseConv2DTranspose(**new_attr_dict), key=key)
else:
### group conv_transpose
layer = Block(
SuperGroupConv2DTranspose(**new_attr_dict), key=key)
model[idx] = layer
elif isinstance(layer, Linear) and (
getattr(self.context, 'expand', None) != None or
getattr(self.context, 'channel', None) != None):
attr_dict = layer.__dict__
key = attr_dict['_full_name']
### TODO(paddle): add _param_attr and _bias_attr as private variable of Linear
#new_attr_name = ['_act', '_dtype', '_param_attr', '_bias_attr']
new_attr_name = ['_act', '_dtype']
in_nc, out_nc = layer._parameters['weight'].shape
new_attr_dict = dict()
new_attr_dict['candidate_config'] = dict()
if self.context.expand:
if idx == first_weight_layer_idx:
new_attr_dict['input_dim'] = int(in_nc)
else:
new_attr_dict['input_dim'] = self.context.expand * int(
in_nc)
if idx == last_weight_layer_idx:
new_attr_dict['output_dim'] = int(out_nc)
else:
new_attr_dict['output_dim'] = self.context.expand * int(
out_nc)
new_attr_dict['candidate_config'].update({
'expand_ratio': self.context.expand_ratio
})
elif self.context.channel:
cur_channel = self.context.channel[0]
self.context.channel = self.context.channel[1:]
if idx == first_weight_layer_idx:
new_attr_dict['input_dim'] = int(in_nc)
else:
new_attr_dict['input_dim'] = max(pre_channel)
if idx == last_weight_layer_idx:
new_attr_dict['output_dim'] = int(out_nc)
else:
new_attr_dict['output_dim'] = max(cur_channel)
new_attr_dict['candidate_config'].update({
'channel': cur_channel
})
pre_channel = cur_channel
else:
new_attr_dict['input_dim'] = int(in_nc)
new_attr_dict['output_dim'] = int(out_nc)
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
del layer, attr_dict
layer = Block(SuperLinear(**new_attr_dict), key=key)
model[idx] = layer
elif isinstance(layer, InstanceNorm) and (
getattr(self.context, 'expand', None) != None or
getattr(self.context, 'channel', None) != None):
# num_features in InstanceNorm layers after the last weight layer does not change
if idx > last_weight_layer_idx:
continue
attr_dict = layer.__dict__
new_attr_name = [
'_param_attr', '_bias_attr', '_dtype', '_epsilon'
]
new_attr_dict = dict()
if self.context.expand:
new_attr_dict['num_channels'] = self.context.expand * int(
layer._parameters['scale'].shape[0])
elif self.context.channel:
new_attr_dict['num_channels'] = max(cur_channel)
for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
del layer, attr_dict
layer = SuperInstanceNorm(**new_attr_dict)
model[idx] = layer
return model
class supernet:
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
assert (
getattr(self, 'expand_ratio', None) == None or
getattr(self, 'channel', None) == None
), "expand_ratio and channel CANNOT be NOT None at the same time."
self.expand = None
if 'expand_ratio' in kwargs.keys():
if isinstance(self.expand_ratio, list) or isinstance(
self.expand_ratio, tuple):
self.expand = max(self.expand_ratio)
elif isinstance(self.expand_ratio, int):
self.expand = self.expand_ratio
def __enter__(self):
return Convert(self)
def __exit__(self, exc_type, exc_val, exc_tb):
pass
#def ofa_supernet(kernel_size, expand_ratio):
# def _ofa_supernet(func):
# @functools.wraps(func)
# def convert(*args, **kwargs):
# supernet_convert(*args, **kwargs)
# return convert
# return _ofa_supernet
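### Sketch (assumption, not part of this file): using the `supernet` context manager
### directly on a small layer list, mirroring the Model defined in the test above.
### Conv2D and BatchNorm are the fluid layers already imported at the top of this file.
def _supernet_convert_example():
    fluid.enable_dygraph()
    with supernet(expand_ratio=[1, 2, 4]) as ofa_super:
        layers = [Conv2D(3, 8, 3), BatchNorm(8), Conv2D(8, 16, 3)]
        layers = ofa_super.convert(layers)
    # after convert: Conv2D -> Block(SuperConv2D), BatchNorm -> SuperBatchNorm
    return paddle.nn.Sequential(*layers)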
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import logging
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.dygraph_utils as dygraph_utils
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import in_dygraph_mode, _varbase_creator
from paddle.fluid.dygraph.nn import InstanceNorm, Conv2D, Conv2DTranspose, BatchNorm
from ...common import get_logger
from .utils.utils import compute_start_end, get_same_padding, convert_to_list
__all__ = [
'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D',
'SuperBatchNorm', 'SuperLinear', 'SuperInstanceNorm', 'Block',
'SuperGroupConv2D', 'SuperDepthwiseConv2D', 'SuperGroupConv2DTranspose',
'SuperDepthwiseConv2DTranspose'
]
_logger = get_logger(__name__, level=logging.INFO)
### TODO: if task is elastic width, need to add re_organize_middle_weight in 1x1 conv in MBBlock
_cnt = 0
def counter():
global _cnt
_cnt += 1
return _cnt
class BaseBlock(fluid.dygraph.Layer):
def __init__(self, key=None):
super(BaseBlock, self).__init__()
if key is not None:
self._key = str(key)
else:
self._key = self.__class__.__name__ + str(counter())
# set SuperNet class
def set_supernet(self, supernet):
self.__dict__['supernet'] = supernet
@property
def key(self):
return self._key
class Block(BaseBlock):
"""
A model is composed of nested Blocks.
Parameters:
fn(Layer): an instance of a super layer, such as: SuperConv2D(3, 5, 3).
key(str, optional): key of this layer; keys correspond one-to-one with candidate configs. Default: None.
"""
def __init__(self, fn, key=None):
super(Block, self).__init__(key)
self.fn = fn
self.candidate_config = self.fn.candidate_config
def forward(self, *inputs, **kwargs):
out = self.supernet.layers_forward(self, *inputs, **kwargs)
return out
class SuperConv2D(fluid.dygraph.Conv2D):
"""
This interface is used to construct a callable object of the ``SuperConv2D`` class.
The difference between ```SuperConv2D``` and ```Conv2D``` is that ```SuperConv2D``` needs
to be fed a config dictionary of the form {'channel': num_of_channel}, which gives
the number of output channels. It is used to change the first dimension of the weight and bias,
so that only the first channels of the weight and bias are trained.
Note: the channel in the config needs to be less than the one first defined.
The super convolution2D layer calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input and
Output are in NCHW format, where N is batch size, C is the number of
the feature map, H is the height of the feature map, and W is the width of the feature map.
Filter's shape is [MCHW] , where M is the number of output feature map,
C is the number of input feature map, H is the height of the filter,
and W is the width of the filter. If the groups is greater than 1,
C will equal the number of input feature map divided by the groups.
Please refer to UFLDL's `convolution
<http://ufldl.stanford.edu/tutorial/supervised/FeatureExtractionUsingConvolution/>`_
for more details.
If bias attribution and activation type are provided, bias is added to the
output of the convolution, and the corresponding activation function is
applied to the final result.
For each input :math:`X`, the equation is:
.. math::
Out = \\sigma (W \\ast X + b)
Where:
* :math:`X`: Input value, a ``Tensor`` with NCHW format.
* :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] .
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D ``Tensor`` with shape [M, 1].
* :math:`\\sigma`: Activation function.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Example:
- Input:
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
Parameters:
num_channels(int): The number of channels in the input image.
num_filters(int): The number of filter. It is as same as the output
feature map.
filter_size (int or tuple): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
candidate_config(dict, optional): Dictionary that describes the candidate config of this layer,
such as {'kernel_size': (3, 5, 7), 'channel': (4, 6, 8)}, which means the kernel size of
this layer can be chosen from (3, 5, 7). The keys of candidate_config
can only be 'kernel_size', 'channel' and 'expand_ratio'; 'channel' and 'expand_ratio'
CANNOT be set at the same time. Default: None.
transform_kernel(bool, optional): Whether to use transform matrix to transform a large filter
to a small filter. Default: False.
stride (int or tuple, optional): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: 1.
padding (int or tuple, optional): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: 0.
dilation (int or tuple, optional): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: 1.
groups (int, optional): The groups number of the Conv2d Layer. According to grouped
convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
the first half of the filters is only connected to the first half
of the input channels, while the second half of the filters is only
connected to the second half of the input channels. Default: 1.
param_attr (ParamAttr, optional): The parameter attribute for learnable weights(Parameter)
of conv2d. If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
bias_attr (ParamAttr or bool, optional): The attribute for the bias of conv2d.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn (bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True.
act (str, optional): Activation type, if it is set to None, activation is not appended.
Default: None.
dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
Attribute:
**weight** (Parameter): the learnable weights of filter of this layer.
**bias** (Parameter or None): the learnable bias of this layer.
Returns:
None
Raises:
ValueError: if ``use_cudnn`` is not a bool value.
Examples:
.. code-block:: python
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddleslim.core.layers import SuperConv2D
import numpy as np
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
with fluid.dygraph.guard():
super_conv2d = SuperConv2D(3, 10, 3)
config = {'channel': 5}
data = to_variable(data)
conv = super_conv2d(data, config)
"""
### NOTE: filter_size, num_channels and num_filters must be the maximum of the candidates to define the largest network.
def __init__(self,
num_channels,
num_filters,
filter_size,
candidate_config={},
transform_kernel=False,
stride=1,
dilation=1,
padding=0,
groups=None,
param_attr=None,
bias_attr=None,
use_cudnn=True,
act=None,
dtype='float32'):
### NOTE: padding is always 0 here; padding is added in forward because the kernel size is uncertain
### TODO: support arbitrary padding
super(SuperConv2D, self).__init__(
num_channels, num_filters, filter_size, stride, padding, dilation,
groups, param_attr, bias_attr, use_cudnn, act, dtype)
if isinstance(self._filter_size, int):
self._filter_size = convert_to_list(self._filter_size, 2)
self.candidate_config = candidate_config
if len(candidate_config.items()) != 0:
for k, v in candidate_config.items():
candidate_config[k] = list(set(v))
self.ks_set = candidate_config[
'kernel_size'] if 'kernel_size' in candidate_config else None
self.expand_ratio = candidate_config[
'expand_ratio'] if 'expand_ratio' in candidate_config else None
self.channel = candidate_config[
'channel'] if 'channel' in candidate_config else None
self.base_channel = None
if self.expand_ratio != None:
self.base_channel = int(self._num_filters / max(self.expand_ratio))
self.transform_kernel = transform_kernel
if self.ks_set != None:
self.ks_set.sort()
if self.transform_kernel != False:
scale_param = dict()
### create parameter to transform kernel
for i in range(len(self.ks_set) - 1):
ks_small = self.ks_set[i]
ks_large = self.ks_set[i + 1]
param_name = '%dto%d_matrix' % (ks_large, ks_small)
ks_t = ks_small**2
scale_param[param_name] = self.create_parameter(
attr=fluid.ParamAttr(
name=self._full_name + param_name,
initializer=fluid.initializer.NumpyArrayInitializer(
np.eye(ks_t))),
shape=(ks_t, ks_t),
dtype=self._dtype)
for name, param in scale_param.items():
setattr(self, name, param)
def get_active_filter(self, in_nc, out_nc, kernel_size):
start, end = compute_start_end(self._filter_size[0], kernel_size)
### if NOT transforming the kernel, slice a centered filter of size kernel_size from the largest filter
filters = self.weight[:out_nc, :in_nc, start:end, start:end]
if self.transform_kernel != False and kernel_size < self._filter_size[
0]:
### if transform kernel, then use matrix to transform
start_filter = self.weight[:out_nc, :in_nc, :, :]
for i in range(len(self.ks_set) - 1, 0, -1):
src_ks = self.ks_set[i]
if src_ks <= kernel_size:
break
target_ks = self.ks_set[i - 1]
start, end = compute_start_end(src_ks, target_ks)
_input_filter = start_filter[:, :, start:end, start:end]
_input_filter = fluid.layers.reshape(
_input_filter,
shape=[(_input_filter.shape[0] * _input_filter.shape[1]),
-1])
core.ops.matmul(_input_filter,
self.__getattr__('%dto%d_matrix' %
(src_ks, target_ks)),
_input_filter, 'transpose_X', False,
'transpose_Y', False, "alpha", 1)
_input_filter = fluid.layers.reshape(
_input_filter,
shape=[
filters.shape[0], filters.shape[1], target_ks, target_ks
])
start_filter = _input_filter
filters = start_filter
return filters
def get_groups_in_out_nc(self, in_nc, out_nc):
### standard conv
return self._groups, in_nc, out_nc
def forward(self, input, kernel_size=None, expand_ratio=None, channel=None):
if not in_dygraph_mode():
_logger.error("NOT support static graph")
in_nc = int(input.shape[1])
assert (
expand_ratio == None or channel == None
), "expand_ratio and channel CANNOT be NOT None at the same time."
if expand_ratio != None:
out_nc = int(expand_ratio * self.base_channel)
elif channel != None:
out_nc = int(channel)
else:
out_nc = self._num_filters
ks = int(self._filter_size[0]) if kernel_size == None else int(
kernel_size)
groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
out_nc)
weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)
padding = convert_to_list(get_same_padding(ks), 2)
if self._l_type == 'conv2d':
attrs = ('strides', self._stride, 'paddings', padding, 'dilations',
self._dilation, 'groups', groups
if groups else 1, 'use_cudnn', self._use_cudnn)
out = core.ops.conv2d(input, weight, *attrs)
elif self._l_type == 'depthwise_conv2d':
attrs = ('strides', self._stride, 'paddings', padding, 'dilations',
self._dilation, 'groups', groups
if groups else self._groups, 'use_cudnn', self._use_cudnn)
out = core.ops.depthwise_conv2d(input, weight, *attrs)
else:
raise ValueError("conv type error")
pre_bias = out
out_nc = int(pre_bias.shape[1])
if self.bias is not None:
bias = self.bias[:out_nc]
pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, 1)
else:
pre_act = pre_bias
return dygraph_utils._append_activation_in_dygraph(pre_act, self._act)
class SuperGroupConv2D(SuperConv2D):
def get_groups_in_out_nc(self, in_nc, out_nc):
### groups convolution
### conv: weight: (Cout, Cin/G, Kh, Kw)
groups = self._groups
in_nc = int(in_nc // groups)
return groups, in_nc, out_nc
class SuperDepthwiseConv2D(SuperConv2D):
### depthwise convolution
def get_groups_in_out_nc(self, in_nc, out_nc):
if in_nc != out_nc:
_logger.debug(
"input channel and output channel in depthwise conv is different, change output channel to input channel! origin channel:(in_nc {}, out_nc {}): ".
format(in_nc, out_nc))
groups = in_nc
out_nc = in_nc
return groups, in_nc, out_nc
class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
"""
This interface is used to construct a callable object of the ``SuperConv2DTranspose``
class.
The difference between ```SuperConv2DTranspose``` and ```Conv2DTranspose``` is that
```SuperConv2DTranspose``` needs to be fed a config dictionary of the form
{'channel': num_of_channel}, which gives the number of output channels. It is used to change
the first dimension of the weight and bias, so that only the first channels of the weight
and bias are trained.
Note: the channel in the config needs to be less than the one first defined.
The super convolution2D transpose layer calculates the output based on the input,
filter, and dilations, strides, paddings. Input and output
are in NCHW format. Where N is batch size, C is the number of feature map,
H is the height of the feature map, and W is the width of the feature map.
Filter's shape is [MCHW] , where M is the number of input feature map,
C is the number of output feature map, H is the height of the filter,
and W is the width of the filter. If the groups is greater than 1,
C will equal the number of input feature map divided by the groups.
If bias attribution and activation type are provided, bias is added to
the output of the convolution, and the corresponding activation function
is applied to the final result.
The details of convolution transpose layer, please refer to the following explanation and references
`conv2dtranspose <http://www.matthewzeiler.com/wp-content/uploads/2017/07/cvpr2010.pdf>`_ .
For each input :math:`X`, the equation is:
.. math::
Out = \sigma (W \\ast X + b)
Where:
* :math:`X`: Input value, a ``Tensor`` with NCHW format.
* :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] .
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D ``Tensor`` with shape [M, 1].
* :math:`\\sigma`: Activation function.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Example:
- Input:
Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
Filter shape: :math:`(C_{in}, C_{out}, H_f, W_f)`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
.. math::
H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\\\
W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1 \\\\
H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] ) \\\\
W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] )
Parameters:
num_channels(int): The number of channels in the input image.
num_filters(int): The number of the filter. It is as same as the output
feature map.
filter_size(int or tuple): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
candidate_config(dict, optional): Dictionary that describes the candidate config of this layer,
such as {'kernel_size': (3, 5, 7), 'channel': (4, 6, 8)}, which means the kernel size of
this layer can be chosen from (3, 5, 7). The keys of candidate_config
can only be 'kernel_size', 'channel' and 'expand_ratio'; 'channel' and 'expand_ratio'
CANNOT be set at the same time. Default: None.
transform_kernel(bool, optional): Whether to use transform matrix to transform a large filter
to a small filter. Default: False.
output_size(int or tuple, optional): The output image size. If output size is a
tuple, it must contain two integers, (image_H, image_W). If it is None,
filter_size, padding, and stride are used to calculate output_size.
If output_size and filter_size are specified at the same time, they
should follow the formula above. Default: None.
padding(int or tuple, optional): The padding size. If padding is a tuple, it must
contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: 0.
stride(int or tuple, optional): The stride size. If stride is a tuple, it must
contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: 1.
dilation(int or tuple, optional): The dilation size. If dilation is a tuple, it must
contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: 1.
groups(int, optional): The groups number of the Conv2d transpose layer. Inspired by
grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
when group=2, the first half of the filters is only connected to the
first half of the input channels, while the second half of the
filters is only connected to the second half of the input channels.
Default: 1.
param_attr (ParamAttr, optional): The parameter attribute for learnable weights(Parameter)
of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose
will create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr or bool, optional): The attribute for the bias of conv2d_transpose.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, conv2d_transpose
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
use_cudnn(bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True.
act (str, optional): Activation type, if it is set to None, activation is not appended.
Default: None.
dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
Attribute:
**weight** (Parameter): the learnable weights of filters of this layer.
**bias** (Parameter or None): the learnable bias of this layer.
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddleslim.core.layers import SuperConv2DTranspose
import numpy as np
with fluid.dygraph.guard():
data = np.random.random((3, 32, 32, 5)).astype('float32')
config = {'channel': 5}
super_convtranspose = SuperConv2DTranspose(num_channels=32, num_filters=10, filter_size=3)
ret = super_convtranspose(fluid.dygraph.base.to_variable(data), config)
"""
def __init__(self,
num_channels,
num_filters,
filter_size,
output_size=None,
candidate_config={},
transform_kernel=False,
stride=1,
dilation=1,
padding=0,
groups=None,
param_attr=None,
bias_attr=None,
use_cudnn=True,
act=None,
dtype='float32'):
### NOTE: padding is always 0 here; padding is added in forward because the kernel size is uncertain
super(SuperConv2DTranspose, self).__init__(
num_channels, num_filters, filter_size, output_size, padding,
stride, dilation, groups, param_attr, bias_attr, use_cudnn, act,
dtype)
self.candidate_config = candidate_config
if len(self.candidate_config.items()) != 0:
for k, v in candidate_config.items():
candidate_config[k] = list(set(v))
self.ks_set = candidate_config[
'kernel_size'] if 'kernel_size' in candidate_config else None
if isinstance(self._filter_size, int):
self._filter_size = convert_to_list(self._filter_size, 2)
self.expand_ratio = candidate_config[
'expand_ratio'] if 'expand_ratio' in candidate_config else None
self.channel = candidate_config[
'channel'] if 'channel' in candidate_config else None
self.base_channel = None
if self.expand_ratio:
self.base_channel = int(self._num_filters / max(self.expand_ratio))
self.transform_kernel = transform_kernel
if self.ks_set != None:
self.ks_set.sort()
if self.transform_kernel != False:
scale_param = dict()
### create parameter to transform kernel
for i in range(len(self.ks_set) - 1):
ks_small = self.ks_set[i]
ks_large = self.ks_set[i + 1]
param_name = '%dto%d_matrix' % (ks_large, ks_small)
ks_t = ks_small**2
scale_param[param_name] = self.create_parameter(
attr=fluid.ParamAttr(
name=self._full_name + param_name,
initializer=fluid.initializer.NumpyArrayInitializer(
np.eye(ks_t))),
shape=(ks_t, ks_t),
dtype=self._dtype)
for name, param in scale_param.items():
setattr(self, name, param)
def get_active_filter(self, in_nc, out_nc, kernel_size):
start, end = compute_start_end(self._filter_size[0], kernel_size)
filters = self.weight[:in_nc, :out_nc, start:end, start:end]
if self.transform_kernel != False and kernel_size < self._filter_size[
0]:
start_filter = self.weight[:in_nc, :out_nc, :, :]
for i in range(len(self.ks_set) - 1, 0, -1):
src_ks = self.ks_set[i]
if src_ks <= kernel_size:
break
target_ks = self.ks_set[i - 1]
start, end = compute_start_end(src_ks, target_ks)
_input_filter = start_filter[:, :, start:end, start:end]
_input_filter = fluid.layers.reshape(
_input_filter,
shape=[(_input_filter.shape[0] * _input_filter.shape[1]),
-1])
core.ops.matmul(_input_filter,
self.__getattr__('%dto%d_matrix' %
(src_ks, target_ks)),
_input_filter, 'transpose_X', False,
'transpose_Y', False, "alpha", 1)
_input_filter = fluid.layers.reshape(
_input_filter,
shape=[
filters.shape[0], filters.shape[1], target_ks, target_ks
])
start_filter = _input_filter
filters = start_filter
return filters
def get_groups_in_out_nc(self, in_nc, out_nc):
### standard conv
return self._groups, in_nc, out_nc
def forward(self, input, kernel_size=None, expand_ratio=None, channel=None):
if not in_dygraph_mode():
_logger.error("NOT support static graph")
in_nc = int(input.shape[1])
assert (
expand_ratio == None or channel == None
), "expand_ratio and channel CANNOT be NOT None at the same time."
if expand_ratio != None:
out_nc = int(expand_ratio * self.base_channel)
elif channel != None:
out_nc = int(channel)
else:
out_nc = self._num_filters
ks = int(self._filter_size[0]) if kernel_size == None else int(
kernel_size)
groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
out_nc)
weight = self.get_active_filter(weight_in_nc, weight_out_nc, ks)
padding = convert_to_list(get_same_padding(ks), 2)
op = getattr(core.ops, self._op_type)
out = op(input, weight, 'output_size', self._output_size, 'strides',
self._stride, 'paddings', padding, 'dilations', self._dilation,
'groups', groups, 'use_cudnn', self._use_cudnn)
pre_bias = out
out_nc = int(pre_bias.shape[1])
if self.bias is not None:
bias = self.bias[:out_nc]
pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, 1)
else:
pre_act = pre_bias
return dygraph_utils._append_activation_in_dygraph(
pre_act, act=self._act)
class SuperGroupConv2DTranspose(SuperConv2DTranspose):
def get_groups_in_out_nc(self, in_nc, out_nc):
### groups convolution
### groups conv transpose: weight: (Cin, Cout/G, Kh, Kw)
groups = self._groups
out_nc = int(out_nc // groups)
return groups, in_nc, out_nc
class SuperDepthwiseConv2DTranspose(SuperConv2DTranspose):
def get_groups_in_out_nc(self, in_nc, out_nc):
if in_nc != out_nc:
_logger.debug(
"input channel and output channel in depthwise conv transpose is different, change output channel to input channel! origin channel:(in_nc {}, out_nc {}): ".
format(in_nc, out_nc))
groups = in_nc
out_nc = in_nc
return groups, in_nc, out_nc
### NOTE: only the channel is searched; written for GAN compression, may be changed to SuperDepthwiseConv and SuperConv later.
class SuperSeparableConv2D(fluid.dygraph.Layer):
"""
This interface is used to construct a callable object of the ``SuperSeparableConv2D``
class.
The difference between ```SuperSeparableConv2D``` and ```SeparableConv2D``` is:
```SuperSeparableConv2D``` needs to be fed a config dictionary of the form
{'channel': num_of_channel}, which gives the number of the first conv's output channels and
the second conv's input channels. It is used to change the first dimension of the weight and bias,
so that only the first channels of the weight and bias are trained.
The architecture of the super separable convolution2D op is [Conv2D, norm layer (BatchNorm
or InstanceNorm), Conv2D]. The first conv is a depthwise conv whose filter number is the input
channel multiplied by scale_factor and whose group number equals the number of input channels.
The second conv is a standard conv with filter size and stride of 1.
Parameters:
num_channels(int): The number of channels in the input image.
num_filters(int): The number of the second conv's filter. It is as same as the output
feature map.
filter_size(int or tuple): The first conv's filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
Otherwise, the filter will be a square.
padding(int or tuple, optional): The first conv's padding size. If padding is a tuple,
it must contain two integers, (padding_H, padding_W). Otherwise, the
padding_H = padding_W = padding. Default: 0.
stride(int or tuple, optional): The first conv's stride size. If stride is a tuple,
it must contain two integers, (stride_H, stride_W). Otherwise, the
stride_H = stride_W = stride. Default: 1.
dilation(int or tuple, optional): The first conv's dilation size. If dilation is a tuple,
it must contain two integers, (dilation_H, dilation_W). Otherwise, the
dilation_H = dilation_W = dilation. Default: 1.
norm_layer(class): The normalization layer between two convolution. Default: InstanceNorm.
bias_attr (ParamAttr or bool, optional): The attribute for the bias of convolution.
If it is set to False, no bias will be added to the output units.
If it is set to None or one attribute of ParamAttr, convolution
will create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
scale_factor(float): The scale factor of the first conv's output channel. Default: 1.
use_cudnn(bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True.
Returns:
None
"""
def __init__(self,
num_channels,
num_filters,
filter_size,
candidate_config={},
stride=1,
padding=0,
dilation=1,
norm_layer=InstanceNorm,
bias_attr=None,
scale_factor=1,
use_cudnn=False):
super(SuperSeparableConv2D, self).__init__()
self.conv = fluid.dygraph.LayerList([
fluid.dygraph.nn.Conv2D(
num_channels=num_channels,
num_filters=num_channels * scale_factor,
filter_size=filter_size,
stride=stride,
padding=padding,
use_cudnn=False,
groups=num_channels,
bias_attr=bias_attr)
])
self.conv.extend([norm_layer(num_channels * scale_factor)])
self.conv.extend([
Conv2D(
num_channels=num_channels * scale_factor,
num_filters=num_filters,
filter_size=1,
stride=1,
use_cudnn=use_cudnn,
bias_attr=bias_attr)
])
self.candidate_config = candidate_config
self.expand_ratio = candidate_config[
'expand_ratio'] if 'expand_ratio' in candidate_config else None
self.base_output_dim = None
if self.expand_ratio != None:
self.base_output_dim = int(num_filters / max(self.expand_ratio))
def forward(self, input, expand_ratio=None, channel=None):
if not in_dygraph_mode():
_logger.error("NOT support static graph")
in_nc = int(input.shape[1])
assert (
expand_ratio == None or channel == None
), "expand_ratio and channel CANNOT be NOT None at the same time."
if expand_ratio != None:
out_nc = int(expand_ratio * self.base_output_dim)
elif channel != None:
out_nc = int(channel)
else:
out_nc = self.conv[0]._num_filters
weight = self.conv[0].weight[:in_nc]
### conv1
if self.conv[0]._l_type == 'conv2d':
attrs = ('strides', self.conv[0]._stride, 'paddings',
self.conv[0]._padding, 'dilations', self.conv[0]._dilation,
'groups', in_nc, 'use_cudnn', self.conv[0]._use_cudnn)
out = core.ops.conv2d(input, weight, *attrs)
elif self.conv[0]._l_type == 'depthwise_conv2d':
attrs = ('strides', self.conv[0]._stride, 'paddings',
self.conv[0]._padding, 'dilations', self.conv[0]._dilation,
'groups', in_nc, 'use_cudnn', self.conv[0]._use_cudnn)
out = core.ops.depthwise_conv2d(input, weight, *attrs)
else:
raise ValueError("conv type error")
pre_bias = out
if self.conv[0].bias is not None:
bias = self.conv[0].bias[:in_nc]
pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, 1)
else:
pre_act = pre_bias
conv0_out = dygraph_utils._append_activation_in_dygraph(
pre_act, self.conv[0]._act)
norm_out = self.conv[1](conv0_out)
weight = self.conv[2].weight[:out_nc, :in_nc, :, :]
if self.conv[2]._l_type == 'conv2d':
attrs = ('strides', self.conv[2]._stride, 'paddings',
self.conv[2]._padding, 'dilations', self.conv[2]._dilation,
'groups', self.conv[2]._groups if self.conv[2]._groups else
1, 'use_cudnn', self.conv[2]._use_cudnn)
out = core.ops.conv2d(norm_out, weight, *attrs)
elif self.conv[2]._l_type == 'depthwise_conv2d':
attrs = ('strides', self.conv[2]._stride, 'paddings',
self.conv[2]._padding, 'dilations', self.conv[2]._dilation,
'groups', self.conv[2]._groups, 'use_cudnn',
self.conv[2]._use_cudnn)
out = core.ops.depthwise_conv2d(norm_out, weight, *attrs)
else:
raise ValueError("conv type error")
pre_bias = out
if self.conv[2].bias is not None:
bias = self.conv[2].bias[:out_nc]
pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, 1)
else:
pre_act = pre_bias
conv1_out = dygraph_utils._append_activation_in_dygraph(
pre_act, self.conv[2]._act)
return conv1_out
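### Sketch (assumption, mirroring the SuperConv2D docstring example): the active
### output channel of a SuperSeparableConv2D is chosen per call through `channel`.
def _super_separable_example():
    with fluid.dygraph.guard():
        sep_conv = SuperSeparableConv2D(6, 12, 3)
        data = fluid.dygraph.to_variable(
            np.random.uniform(-1, 1, [1, 6, 32, 32]).astype('float32'))
        # only the first 4 of the 12 output filters are used
        return sep_conv(data, channel=4)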
class SuperLinear(fluid.dygraph.Linear):
"""
"""
def __init__(self,
input_dim,
output_dim,
candidate_config={},
param_attr=None,
bias_attr=None,
act=None,
dtype="float32"):
super(SuperLinear, self).__init__(input_dim, output_dim, param_attr,
bias_attr, act, dtype)
self._param_attr = param_attr
self._bias_attr = bias_attr
self.output_dim = output_dim
self.candidate_config = candidate_config
self.expand_ratio = candidate_config[
'expand_ratio'] if 'expand_ratio' in candidate_config else None
self.base_output_dim = None
if self.expand_ratio != None:
self.base_output_dim = int(self.output_dim / max(self.expand_ratio))
def forward(self, input, expand_ratio=None, channel=None):
if not in_dygraph_mode():
_logger.error("NOT support static graph")
### weight: (Cin, Cout)
in_nc = int(input.shape[1])
assert (
expand_ratio == None or channel == None
), "expand_ratio and channel CANNOT be NOT None at the same time."
if expand_ratio != None:
out_nc = int(expand_ratio * self.base_output_dim)
elif channel != None:
out_nc = int(channel)
else:
out_nc = self.output_dim
weight = self.weight[:in_nc, :out_nc]
if self._bias_attr != False:
bias = self.bias[:out_nc]
use_bias = True
pre_bias = _varbase_creator(dtype=input.dtype)
core.ops.matmul(input, weight, pre_bias, 'transpose_X', False,
'transpose_Y', False, "alpha", 1)
if self._bias_attr != False:
pre_act = dygraph_utils._append_bias_in_dygraph(
pre_bias, bias, axis=len(input.shape) - 1)
else:
pre_act = pre_bias
return dygraph_utils._append_activation_in_dygraph(pre_act, self._act)
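### Sketch (assumption): SuperLinear keeps the full (input_dim, output_dim) weight
### and slices it at runtime according to the actual input width and the requested
### expand_ratio or channel.
def _super_linear_example():
    with fluid.dygraph.guard():
        sup_fc = SuperLinear(64, 32, candidate_config={'expand_ratio': (1, 2)})
        data = fluid.dygraph.to_variable(
            np.random.uniform(-1, 1, [4, 64]).astype('float32'))
        # base_output_dim = 32 / max(expand_ratio) = 16, so 16 output units are used
        return sup_fc(data, expand_ratio=1)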
class SuperBatchNorm(fluid.dygraph.BatchNorm):
"""
Super batch normalization layer: the weight, bias and running statistics are sliced at runtime to match the channel dimension of the input.
"""
def __init__(self,
num_channels,
act=None,
is_test=False,
momentum=0.9,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
dtype='float32',
data_layout='NCHW',
in_place=False,
moving_mean_name=None,
moving_variance_name=None,
do_model_average_for_mean_and_var=True,
use_global_stats=False,
trainable_statistics=False):
super(SuperBatchNorm, self).__init__(
num_channels, act, is_test, momentum, epsilon, param_attr,
bias_attr, dtype, data_layout, in_place, moving_mean_name,
moving_variance_name, do_model_average_for_mean_and_var,
use_global_stats, trainable_statistics)
def forward(self, input):
if not in_dygraph_mode():
_logger.error("NOT support static graph")
feature_dim = int(input.shape[1])
weight = self.weight[:feature_dim]
bias = self.bias[:feature_dim]
mean = self._mean[:feature_dim]
variance = self._variance[:feature_dim]
mean_out = mean
variance_out = variance
attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
"is_test", not self.training, "data_layout", self._data_layout,
"use_mkldnn", False, "fuse_with_relu", self._fuse_with_relu,
"use_global_stats", self._use_global_stats,
'trainable_statistics', self._trainable_statistics)
batch_norm_out, _, _, _, _, _ = core.ops.batch_norm(
input, weight, bias, mean, variance, mean_out, variance_out, *attrs)
return dygraph_utils._append_activation_in_dygraph(
batch_norm_out, act=self._act)
class SuperInstanceNorm(fluid.dygraph.InstanceNorm):
"""
"""
def __init__(self,
num_channels,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
dtype='float32'):
super(SuperInstanceNorm, self).__init__(num_channels, epsilon,
param_attr, bias_attr, dtype)
def forward(self, input):
if not in_dygraph_mode():
_logger.error("NOT support static graph")
feature_dim = int(input.shape[1])
if self._param_attr == False and self._bias_attr == False:
scale = None
bias = None
else:
scale = self.scale[:feature_dim]
bias = self.bias[:feature_dim]
out, _, _ = core.ops.instance_norm(input, scale, bias, 'epsilon',
self._epsilon)
return out
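### Sketch (assumption): the Super* norm layers follow a width-elastic conv by
### slicing their parameters and statistics to the channel dimension of the input
### they actually receive.
def _super_batchnorm_example():
    with fluid.dygraph.guard():
        bn = SuperBatchNorm(10)
        data = fluid.dygraph.to_variable(
            np.random.uniform(-1, 1, [2, 6, 8, 8]).astype('float32'))
        # only the first 6 of the 10 channels' weight/bias/statistics are used
        return bn(data)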
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import numpy as np
from collections import namedtuple
import paddle
import paddle.nn as nn
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D
from .layers import BaseBlock, Block, SuperConv2D, SuperBatchNorm
from .utils.utils import search_idx
from ...common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
__all__ = ['OFA', 'RunConfig', 'DistillConfig']
RunConfig = namedtuple('RunConfig', [
'train_batch_size', 'eval_batch_size', 'n_epochs', 'save_frequency',
'eval_frequency', 'init_learning_rate', 'total_images', 'elastic_depth',
'dynamic_batch_size'
])
RunConfig.__new__.__defaults__ = (None, ) * len(RunConfig._fields)
DistillConfig = namedtuple('DistillConfig', [
'lambda_distill', 'teacher_model', 'mapping_layers', 'teacher_model_path',
'distill_fn'
])
DistillConfig.__new__.__defaults__ = (None, ) * len(DistillConfig._fields)
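### Sketch (assumption): both configs are plain namedtuples whose unset fields
### default to None, so a run only needs to fill in the fields it actually uses
### (compare with test_ofa at the top of this commit).
def _example_configs():
    run_config = RunConfig(
        train_batch_size=256,
        n_epochs=[[1]],
        init_learning_rate=[[0.001]],
        total_images=50000,
        dynamic_batch_size=[1])
    distill_config = DistillConfig(
        lambda_distill=0.01,
        teacher_model=None,  # the teacher model class would go here, see the test
        mapping_layers=['models.0.fn'])
    return run_config, distill_config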
class OFABase(fluid.dygraph.Layer):
def __init__(self, model):
super(OFABase, self).__init__()
self.model = model
self._layers, self._elastic_task = self.get_layers()
def get_layers(self):
layers = dict()
elastic_task = set()
for name, sublayer in self.model.named_sublayers():
if isinstance(sublayer, BaseBlock):
sublayer.set_supernet(self)
layers[sublayer.key] = sublayer.candidate_config
for k in sublayer.candidate_config.keys():
elastic_task.add(k)
return layers, elastic_task
def forward(self, *inputs, **kwargs):
raise NotImplementedError
# NOTE: current_config sets the forward config for each layer; it is also used in distillation.
def layers_forward(self, block, *inputs, **kwargs):
if getattr(self, 'current_config', None) != None:
assert block.key in self.current_config, 'layer {} is not in the current config.'.format(
block.key)
config = self.current_config[block.key]
else:
config = dict()
_logger.debug("{}, {}".format(self.model, config))
return block.fn(*inputs, **config)
@property
def layers(self):
return self._layers
class OFA(OFABase):
def __init__(self,
model,
run_config,
net_config=None,
distill_config=None,
elastic_order=None,
train_full=False):
super(OFA, self).__init__(model)
self.net_config = net_config
self.run_config = run_config
self.distill_config = distill_config
self.elastic_order = elastic_order
self.train_full = train_full
self.iter_per_epochs = self.run_config.total_images // self.run_config.train_batch_size
self.iter = 0
self.dynamic_iter = 0
self.manual_set_task = False
self.task_idx = 0
self._add_teacher = False
self.netAs_param = []
for idx in range(len(run_config.n_epochs)):
assert isinstance(
run_config.init_learning_rate[idx],
list), "each candidate in init_learning_rate must be list"
assert isinstance(run_config.n_epochs[idx],
list), "each candidate in n_epochs must be list"
### if elastic_order is none, use default order
if self.elastic_order is not None:
assert isinstance(self.elastic_order,
list), 'elastic_order must be a list'
if self.elastic_order is None:
self.elastic_order = []
# zero: elastic resolution, written in the demo
# first, elastic kernel size
if 'kernel_size' in self._elastic_task:
self.elastic_order.append('kernel_size')
# second, elastic depth, such as: list(2, 3, 4)
if getattr(self.run_config, 'elastic_depth', None) != None:
depth_list = list(set(self.run_config.elastic_depth))
depth_list.sort()
self.layers['depth'] = depth_list
self.elastic_order.append('depth')
# final, elastic width
if 'expand_ratio' in self._elastic_task:
self.elastic_order.append('width')
if 'channel' in self._elastic_task and 'width' not in self.elastic_order:
self.elastic_order.append('width')
assert len(self.run_config.n_epochs) == len(self.elastic_order)
assert len(self.run_config.n_epochs) == len(
self.run_config.dynamic_batch_size)
assert len(self.run_config.n_epochs) == len(
self.run_config.init_learning_rate)
### ================= add distill prepare ======================
if self.distill_config != None and (
self.distill_config.lambda_distill != None and
self.distill_config.lambda_distill > 0):
self._add_teacher = True
self._prepare_distill()
self.model.train()
def _prepare_distill(self):
self.Tacts, self.Sacts = {}, {}
if self.distill_config.teacher_model == None:
_logger.error(
'If you want to add distillation, please pass the class of the teacher model'
)
assert isinstance(self.distill_config.teacher_model,
paddle.fluid.dygraph.Layer)
# load teacher parameter
if self.distill_config.teacher_model_path != None:
param_state_dict, _ = paddle.load_dygraph(
self.distill_config.teacher_model_path)
self.distill_config.teacher_model.set_dict(param_state_dict)
self.ofa_teacher_model = OFABase(self.distill_config.teacher_model)
self.ofa_teacher_model.model.eval()
# add hooks if mapping_layers is not None
# if mapping_layers is None, only return the output of the teacher model,
# if mapping_layers is NOT None, add hooks and compute the distill loss on the mapping layers.
mapping_layers = self.distill_config.mapping_layers
if mapping_layers != None:
self.netAs = []
for name, sublayer in self.model.named_sublayers():
if name in mapping_layers:
netA = SuperConv2D(
sublayer._num_filters,
sublayer._num_filters,
filter_size=1)
self.netAs_param.extend(netA.parameters())
self.netAs.append(netA)
def get_activation(mem, name):
def get_output_hook(layer, input, output):
mem[name] = output
return get_output_hook
def add_hook(net, mem, mapping_layers):
for idx, (n, m) in enumerate(net.named_sublayers()):
if n in mapping_layers:
m.register_forward_post_hook(get_activation(mem, n))
add_hook(self.model, self.Sacts, mapping_layers)
add_hook(self.ofa_teacher_model.model, self.Tacts, mapping_layers)
def _compute_epochs(self):
if getattr(self, 'epoch', None) == None:
epoch = self.iter // self.iter_per_epochs
else:
epoch = self.epoch
return epoch
def _sample_from_nestdict(self, cands, sample_type, task, phase):
sample_cands = dict()
for k, v in cands.items():
if isinstance(v, dict):
sample_cands[k] = self._sample_from_nestdict(
v, sample_type=sample_type, task=task, phase=phase)
elif isinstance(v, list) or isinstance(v, set) or isinstance(v,
tuple):
if sample_type == 'largest':
sample_cands[k] = v[-1]
elif sample_type == 'smallest':
sample_cands[k] = v[0]
else:
if k not in task:
                        # candidate lists are already sorted and deduplicated;
                        # keys outside the current task list are fixed to their
                        # largest candidate.
sample_cands[k] = v[-1]
else:
                        # phase == None -> sample from all candidates;
                        # phase == n    -> each phase adds the next-smaller
                        #                  candidates to the sampling pool.
                        # phase only affects the last task in the current task list.
if phase != None and k == task[-1]:
start = -(phase + 2)
else:
start = 0
sample_cands[k] = np.random.choice(v[start:])
return sample_cands
def _sample_config(self, task, sample_type='random', phase=None):
config = self._sample_from_nestdict(
self.layers, sample_type=sample_type, task=task, phase=phase)
return config
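    # set_task/set_epoch let the caller override progressive shrinking manually:
    # once manual_set_task is True, forward() samples sub-networks from the given
    # task (and phase) instead of deriving them from the current epoch.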
def set_task(self, task=None, phase=None):
self.manual_set_task = True
self.task = task
self.phase = phase
def set_epoch(self, epoch):
self.epoch = epoch
def _progressive_shrinking(self):
epoch = self._compute_epochs()
self.task_idx, phase_idx = search_idx(epoch, self.run_config.n_epochs)
self.task = self.elastic_order[:self.task_idx + 1]
if 'width' in self.task:
            ### replace the generic 'width' task with the concrete elastic tasks
self.task.remove('width')
if 'expand_ratio' in self._elastic_task:
self.task.append('expand_ratio')
if 'channel' in self._elastic_task:
self.task.append('channel')
if len(self.run_config.n_epochs[self.task_idx]) == 1:
phase_idx = None
return self._sample_config(task=self.task, phase=phase_idx)
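    # Example of the schedule above (epoch counts match the n_epochs used in the
    # unit test at the end of this change): with n_epochs = [[1], [2, 3], [4, 5]]
    # and elastic_order = ['kernel_size', 'depth', 'width'], epoch 3 maps to
    # task_idx = 1, so task = ['kernel_size', 'depth']; 'width' (i.e.
    # 'expand_ratio'/'channel') is only added once training reaches the last
    # n_epochs stage.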
def calc_distill_loss(self):
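        # For each mapped layer, pass the cached student activation through its
        # 1x1 SuperConv2D adapter (netA), compare it with the cached teacher
        # activation (MSE by default, or distill_fn if provided), and scale the
        # summed loss by lambda_distill.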
losses = []
assert len(self.netAs) > 0
for i, netA in enumerate(self.netAs):
assert isinstance(netA, SuperConv2D)
n = self.distill_config.mapping_layers[i]
Tact = self.Tacts[n]
Sact = self.Sacts[n]
Sact = netA(Sact, channel=netA._num_filters)
            if self.distill_config.distill_fn is None:
                loss = fluid.layers.mse_loss(Sact, Tact)
            else:
                loss = self.distill_config.distill_fn(Sact, Tact)
losses.append(loss)
return sum(losses) * self.distill_config.lambda_distill
### TODO: complete it
def search(self, eval_func, condition):
pass
### TODO: complete it
def export(self, config):
pass
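    # Typical training step with this wrapper (a minimal sketch; names are
    # illustrative and mirror the unit test below):
    #   ofa_model = OFA(model, run_config, distill_config=distill_config)
    #   output, teacher_output = ofa_model(image)     # samples one sub-network
    #   loss = fluid.layers.reduce_mean(output)
    #   loss += ofa_model.calc_distill_loss()         # only if distillation is on
    #   loss.backward(); adam.minimize(loss); adam.clear_gradients()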
def forward(self, *inputs, **kwargs):
# ===================== teacher process =====================
teacher_output = None
if self._add_teacher:
teacher_output = self.ofa_teacher_model.model.forward(*inputs,
**kwargs)
# ============================================================
# ==================== student process =====================
self.dynamic_iter += 1
if self.dynamic_iter == self.run_config.dynamic_batch_size[
self.task_idx]:
self.iter += 1
self.dynamic_iter = 0
if self.net_config == None:
if self.train_full == True:
self.current_config = self._sample_config(
task=None, sample_type='largest')
else:
if self.manual_set_task == False:
self.current_config = self._progressive_shrinking()
else:
self.current_config = self._sample_config(
self.task, phase=self.phase)
else:
self.current_config = self.net_config
_logger.debug("Current config is {}".format(self.current_config))
if 'depth' in self.current_config:
kwargs['depth'] = int(self.current_config['depth'])
return self.model.forward(*inputs, **kwargs), teacher_output
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .utils import *
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def compute_start_end(kernel_size, sub_kernel_size):
center = kernel_size // 2
sub_center = sub_kernel_size // 2
start = center - sub_center
end = center + sub_center + 1
assert end - start == sub_kernel_size
return start, end
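# Example: compute_start_end(7, 3) returns (2, 5), i.e. a 3x3 kernel is taken
# from the center of a 7x7 kernel.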
def get_same_padding(kernel_size):
assert isinstance(kernel_size, int)
    assert kernel_size % 2 > 0, "kernel size must be an odd number"
return kernel_size // 2
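# Example: get_same_padding(5) == 2, which preserves the spatial size for a
# stride-1 convolution with a 5x5 kernel.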
def convert_to_list(value, n):
return [value, ] * n
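# Example: convert_to_list(3, 2) == [3, 3].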
def search_idx(num, sorted_nestlist):
max_num = -1
max_idx = -1
for idx in range(len(sorted_nestlist)):
task_ = sorted_nestlist[idx]
max_num = task_[-1]
max_idx = len(task_) - 1
for phase_idx in range(len(task_)):
if num <= task_[phase_idx]:
return idx, phase_idx
assert num > max_num
return len(sorted_nestlist) - 1, max_idx
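# Example: search_idx(3, [[1], [2, 3], [4, 5]]) == (1, 1) -- epoch 3 falls into
# the second stage, second phase; search_idx(6, [[1], [2, 3], [4, 5]]) == (2, 1)
# because 6 exceeds every bound and the last stage/phase is returned.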
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import numpy as np
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph.nn as nn
from paddle.nn import ReLU
from paddleslim.nas import ofa
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
from paddleslim.nas.ofa.convert_super import supernet
from paddleslim.nas.ofa.layers import Block, SuperSeparableConv2D
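# ModelConv chains two supernet-converted blocks (the first searched over
# kernel_size/channel, the second over kernel_size/expand_ratio) with a manually
# wrapped Block(SuperSeparableConv2D(...)) in between; ModelLinear below does the
# same for Linear layers with expand_ratio and channel search.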
class ModelConv(fluid.dygraph.Layer):
def __init__(self):
super(ModelConv, self).__init__()
with supernet(
kernel_size=(3, 5, 7),
channel=((4, 8, 12), (8, 12, 16), (8, 12, 16),
(8, 12, 16))) as ofa_super:
models = []
models += [nn.Conv2D(3, 4, 3)]
models += [nn.InstanceNorm(4)]
models += [ReLU()]
models += [nn.Conv2D(4, 4, 3, groups=4)]
models += [nn.InstanceNorm(4)]
models += [ReLU()]
models += [nn.Conv2DTranspose(4, 4, 3, groups=4, use_cudnn=True)]
models += [nn.BatchNorm(4)]
models += [ReLU()]
models += [nn.Conv2D(4, 3, 3)]
models += [ReLU()]
models = ofa_super.convert(models)
models += [
Block(
SuperSeparableConv2D(
3, 6, 1, candidate_config={'channel': (3, 6)}))
]
with supernet(
kernel_size=(3, 5, 7), expand_ratio=(1, 2, 4)) as ofa_super:
models1 = []
models1 += [nn.Conv2D(6, 4, 3)]
models1 += [nn.BatchNorm(4)]
models1 += [ReLU()]
models1 += [nn.Conv2D(4, 4, 3, groups=2)]
models1 += [nn.InstanceNorm(4)]
models1 += [ReLU()]
models1 += [nn.Conv2DTranspose(4, 4, 3, groups=2)]
models1 += [nn.BatchNorm(4)]
models1 += [ReLU()]
models1 += [nn.Conv2DTranspose(4, 4, 3)]
models1 += [nn.BatchNorm(4)]
models1 += [ReLU()]
models1 = ofa_super.convert(models1)
models += models1
self.models = paddle.nn.Sequential(*models)
def forward(self, inputs, depth=None):
if depth != None:
assert isinstance(depth, int)
assert depth <= len(self.models)
else:
depth = len(self.models)
for idx in range(depth):
layer = self.models[idx]
inputs = layer(inputs)
return inputs
class ModelLinear(fluid.dygraph.Layer):
def __init__(self):
super(ModelLinear, self).__init__()
models = []
with supernet(expand_ratio=(1, 2, 4)) as ofa_super:
models1 = []
models1 += [nn.Linear(64, 128)]
models1 += [nn.Linear(128, 256)]
models1 = ofa_super.convert(models1)
models += models1
with supernet(channel=((64, 128, 256), (64, 128, 256))) as ofa_super:
models1 = []
models1 += [nn.Linear(256, 128)]
models1 += [nn.Linear(128, 256)]
models1 = ofa_super.convert(models1)
models += models1
self.models = paddle.nn.Sequential(*models)
def forward(self, inputs, depth=None):
if depth != None:
assert isinstance(depth, int)
assert depth < len(self.models)
else:
depth = len(self.models)
for idx in range(depth):
layer = self.models[idx]
inputs = layer(inputs)
return inputs
class TestOFA(unittest.TestCase):
def setUp(self):
fluid.enable_dygraph()
self.init_model_and_data()
self.init_config()
def init_model_and_data(self):
self.model = ModelConv()
self.teacher_model = ModelConv()
data_np = np.random.random((1, 3, 10, 10)).astype(np.float32)
label_np = np.random.random((1)).astype(np.float32)
self.data = fluid.dygraph.to_variable(data_np)
def init_config(self):
default_run_config = {
'train_batch_size': 1,
'eval_batch_size': 1,
'n_epochs': [[1], [2, 3], [4, 5]],
'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
'dynamic_batch_size': [1, 1, 1],
'total_images': 1,
'elastic_depth': (5, 15, 24)
}
self.run_config = RunConfig(**default_run_config)
default_distill_config = {
'lambda_distill': 0.01,
'teacher_model': self.teacher_model,
'mapping_layers': ['models.0.fn']
}
self.distill_config = DistillConfig(**default_distill_config)
def test_ofa(self):
ofa_model = OFA(self.model,
self.run_config,
distill_config=self.distill_config)
start_epoch = 0
for idx in range(len(self.run_config.n_epochs)):
cur_idx = self.run_config.n_epochs[idx]
for ph_idx in range(len(cur_idx)):
cur_lr = self.run_config.init_learning_rate[idx][ph_idx]
adam = fluid.optimizer.Adam(
learning_rate=cur_lr,
parameter_list=(
ofa_model.parameters() + ofa_model.netAs_param))
for epoch_id in range(start_epoch,
self.run_config.n_epochs[idx][ph_idx]):
for model_no in range(self.run_config.dynamic_batch_size[
idx]):
output, _ = ofa_model(self.data)
loss = fluid.layers.reduce_mean(output)
if self.distill_config.mapping_layers != None:
dis_loss = ofa_model.calc_distill_loss()
loss += dis_loss
dis_loss = dis_loss.numpy()[0]
else:
dis_loss = 0
print('epoch: {}, loss: {}, distill loss: {}'.format(
epoch_id, loss.numpy()[0], dis_loss))
loss.backward()
adam.minimize(loss)
adam.clear_gradients()
start_epoch = self.run_config.n_epochs[idx][ph_idx]
class TestOFACase1(TestOFA):
def init_model_and_data(self):
self.model = ModelLinear()
self.teacher_model = ModelLinear()
data_np = np.random.random((3, 64)).astype(np.float32)
self.data = fluid.dygraph.to_variable(data_np)
def init_config(self):
default_run_config = {
'train_batch_size': 1,
'eval_batch_size': 1,
'n_epochs': [[2, 5]],
'init_learning_rate': [[0.003, 0.001]],
'dynamic_batch_size': [1],
'total_images': 1,
}
self.run_config = RunConfig(**default_run_config)
default_distill_config = {
'lambda_distill': 0.01,
'teacher_model': self.teacher_model,
}
self.distill_config = DistillConfig(**default_distill_config)
if __name__ == '__main__':
unittest.main()