From 3fe05f00e5e5a69bf7000c303520e2264291f4f0 Mon Sep 17 00:00:00 2001
From: ceci3 <592712189@qq.com>
Date: Mon, 13 Jul 2020 03:19:14 +0000
Subject: [PATCH] add comment

---
 paddleslim/core/__init__.py          |   4 +-
 paddleslim/core/layers.py            | 320 +++++++++++++++++++++------
 paddleslim/models/dygraph/modules.py |   4 +-
 3 files changed, 256 insertions(+), 72 deletions(-)

diff --git a/paddleslim/core/__init__.py b/paddleslim/core/__init__.py
index a9729102..404c6593 100644
--- a/paddleslim/core/__init__.py
+++ b/paddleslim/core/__init__.py
@@ -14,6 +14,6 @@
 from .graph_wrapper import GraphWrapper, VarWrapper, OpWrapper
 from .registry import Registry
-from .layers import SuperInstanceNorm, SuperConv2D, SuperConv2DTranspose, SuperSeparableConv2D
+from .layers import SuperConv2D, SuperConv2DTranspose, SuperSeparableConv2D
 
-__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper', 'Registry', 'SuperInstanceNorm', 'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D']
+__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper', 'Registry', 'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D']
diff --git a/paddleslim/core/layers.py b/paddleslim/core/layers.py
index d676e4e0..efe5be3e 100644
--- a/paddleslim/core/layers.py
+++ b/paddleslim/core/layers.py
@@ -14,65 +14,117 @@
 import paddle.fluid as fluid
 import paddle.fluid.dygraph_utils as dygraph_utils
-from paddle.fluid.data_feeder import check_variable_and_dtype, check_type
-from paddle.fluid.dygraph.base import to_variable
+from paddle.fluid.data_feeder import check_variable_and_dtype
 from paddle.fluid.framework import in_dygraph_mode
 from paddle.fluid.dygraph.nn import InstanceNorm, Conv2D, Conv2DTranspose
 import paddle.fluid.core as core
 import numpy as np
 
-__all__ = ['SuperInstanceNorm', 'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D']
-
-
-### NOTE: this op can delete after this pr merged: https://github.com/PaddlePaddle/Paddle/pull/24717
-class SuperInstanceNorm(fluid.dygraph.InstanceNorm):
-    def __init__(self,
-                 num_channels,
-                 epsilon=1e-5,
-                 param_attr=None,
-                 bias_attr=None,
-                 dtype='float32'):
-        super(SuperInstanceNorm, self).__init__(
-            num_channels,
-            epsilon=1e-5,
-            param_attr=None,
-            bias_attr=None,
-            dtype='float32')
-
-    def forward(self, input):
-        in_nc = int(input.shape[1])
-        scale = self.scale[:in_nc]
-        bias = self.scale[:in_nc]
-        if in_dygraph_mode():
-            out, _, _ = core.ops.instance_norm(input, scale, bias, 'epsilon',
-                                               self._epsilon)
-            return out
-        check_variable_and_dtype(input, 'input', ['float32', 'float64'],
-                                 "SuperInstanceNorm")
-
-        attrs = {"epsilon": self._epsilon}
-
-        inputs = {"X": [input], "Scale": [scale], "Bias": [bias]}
-
-        saved_mean = self._helper.create_variable_for_type_inference(
-            dtype=self._dtype, stop_gradient=True)
-        saved_variance = self._helper.create_variable_for_type_inference(
-            dtype=self._dtype, stop_gradient=True)
-        instance_norm_out = self._helper.create_variable_for_type_inference(
-            self._dtype)
-
-        outputs = {
-            "Y": [instance_norm_out],
-            "SavedMean": [saved_mean],
-            "SavedVariance": [saved_variance]
-        }
-
-        self._helper.append_op(
-            type="instance_norm", inputs=inputs, outputs=outputs, attrs=attrs)
-        return instance_norm_out
-
+__all__ = ['SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D']
 
 class SuperConv2D(fluid.dygraph.Conv2D):
+    """
+    This interface is used to construct a callable object of the ``SuperConv2D`` class.
+    The difference between ``SuperConv2D`` and ``Conv2D`` is that ``SuperConv2D`` must be
+    fed a config dictionary in the format {'channel': num_of_channel}, which specifies the
+    number of output channels. It is used to slice the first dimension of the weight and
+    bias, so only the first `channel` output channels of the weight and bias are used and
+    trained.
+
+    Note: the channel in config must not be larger than the number of filters defined at
+    construction time.
+
+    The super convolution2D layer calculates the output based on the input, filter
+    and strides, paddings, dilations, groups parameters. Input and
+    Output are in NCHW format, where N is batch size, C is the number of channels of
+    the feature map, H is the height of the feature map, and W is the width of the feature map.
+    Filter's shape is [MCHW], where M is the number of output feature maps,
+    C is the number of input feature maps, H is the height of the filter,
+    and W is the width of the filter. If the groups is greater than 1,
+    C will equal the number of input feature maps divided by the groups.
+    Please refer to UFLDL's `convolution `_ for more details.
+    If bias attribution and activation type are provided, bias is added to the
+    output of the convolution, and the corresponding activation function is
+    applied to the final result.
+    For each input :math:`X`, the equation is:
+    .. math::
+        Out = \\sigma (W \\ast X + b)
+    Where:
+    * :math:`X`: Input value, a ``Tensor`` with NCHW format.
+    * :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] .
+    * :math:`\\ast`: Convolution operation.
+    * :math:`b`: Bias value, a 2-D ``Tensor`` with shape [M, 1].
+    * :math:`\\sigma`: Activation function.
+    * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
+
+    Example:
+        - Input:
+          Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
+          Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
+        - Output:
+          Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
+        Where
+        .. math::
+            H_{out}&= \\frac{(H_{in} + 2 * paddings[0] - (dilations[0] * (H_f - 1) + 1))}{strides[0]} + 1 \\\\
+            W_{out}&= \\frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]} + 1
+
+    Parameters:
+        num_channels(int): The number of channels in the input image.
+        num_filters(int): The number of filters. It is the same as the number of channels
+            of the output feature map.
+        filter_size (int or tuple): The filter size. If filter_size is a tuple,
+            it must contain two integers, (filter_size_H, filter_size_W).
+            Otherwise, the filter will be a square.
+        stride (int or tuple, optional): The stride size. If stride is a tuple, it must
+            contain two integers, (stride_H, stride_W). Otherwise, the
+            stride_H = stride_W = stride. Default: 1.
+        padding (int or tuple, optional): The padding size. If padding is a tuple, it must
+            contain two integers, (padding_H, padding_W). Otherwise, the
+            padding_H = padding_W = padding. Default: 0.
+        dilation (int or tuple, optional): The dilation size. If dilation is a tuple, it must
+            contain two integers, (dilation_H, dilation_W). Otherwise, the
+            dilation_H = dilation_W = dilation. Default: 1.
+        groups (int, optional): The groups number of the Conv2d Layer. According to grouped
+            convolution in Alex Krizhevsky's Deep CNN paper: when group=2,
+            the first half of the filters is only connected to the first half
+            of the input channels, while the second half of the filters is only
+            connected to the second half of the input channels. Default: 1.
+        param_attr (ParamAttr, optional): The parameter attribute for learnable weights(Parameter)
+            of conv2d.
+            If it is set to None or one attribute of ParamAttr, conv2d
+            will create ParamAttr as param_attr. If the Initializer of the param_attr
+            is not set, the parameter is initialized with :math:`Normal(0.0, std)`,
+            and the :math:`std` is :math:`(\\frac{2.0 }{filter\_elem\_num})^{0.5}`. Default: None.
+        bias_attr (ParamAttr or bool, optional): The attribute for the bias of conv2d.
+            If it is set to False, no bias will be added to the output units.
+            If it is set to None or one attribute of ParamAttr, conv2d
+            will create ParamAttr as bias_attr. If the Initializer of the bias_attr
+            is not set, the bias is initialized zero. Default: None.
+        use_cudnn (bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
+            library is installed. Default: True.
+        act (str, optional): Activation type, if it is set to None, activation is not appended.
+            Default: None.
+        dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
+
+    Attribute:
+        **weight** (Parameter): the learnable weights of filter of this layer.
+        **bias** (Parameter or None): the learnable bias of this layer.
+
+    Returns:
+        None
+
+    Raises:
+        ValueError: if ``use_cudnn`` is not a bool value.
+
+    Examples:
+        .. code-block:: python
+
+          from paddle.fluid.dygraph.base import to_variable
+          import paddle.fluid as fluid
+          from paddleslim.core.layers import SuperConv2D
+          import numpy as np
+
+          data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
+          with fluid.dygraph.guard():
+              super_conv2d = SuperConv2D(3, 10, 3)
+              config = {'channel': 5}
+              data = to_variable(data)
+              conv = super_conv2d(data, config)
+    """
     def __init__(self,
                  num_channels,
                  num_filters,
@@ -94,7 +146,6 @@ class SuperConv2D(fluid.dygraph.Conv2D):
         in_nc = int(input.shape[1])
         out_nc = config['channel']
         weight = self.weight[:out_nc, :in_nc, :, :]
-        #print('super conv shape', weight.shape)
         if in_dygraph_mode():
             if self._l_type == 'conv2d':
                 attrs = ('strides', self._stride, 'paddings', self._padding,
@@ -161,6 +212,109 @@ class SuperConv2D(fluid.dygraph.Conv2D):
 
 
 class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
+    """
+    This interface is used to construct a callable object of the ``SuperConv2DTranspose``
+    class.
+    The difference between ``SuperConv2DTranspose`` and ``Conv2DTranspose`` is that
+    ``SuperConv2DTranspose`` must be fed a config dictionary in the format
+    {'channel': num_of_channel}, which specifies the number of output channels. It is used
+    to slice the first dimension of the weight and bias, so only the first `channel`
+    output channels of the weight and bias are used and trained.
+
+    Note: the channel in config must not be larger than the number of filters defined at
+    construction time.
+
+    The super convolution2D transpose layer calculates the output based on the input,
+    filter, and dilations, strides, paddings. Input and output
+    are in NCHW format, where N is batch size, C is the number of feature maps,
+    H is the height of the feature map, and W is the width of the feature map.
+    Filter's shape is [MCHW], where M is the number of input feature maps,
+    C is the number of output feature maps, H is the height of the filter,
+    and W is the width of the filter. If the groups is greater than 1,
+    C will equal the number of input feature maps divided by the groups.
+    If bias attribution and activation type are provided, bias is added to
+    the output of the convolution, and the corresponding activation function
+    is applied to the final result.
+    For details of the convolution transpose layer, please refer to the following
+    explanation and references: `conv2dtranspose `_ .
+    For each input :math:`X`, the equation is:
+    .. math::
+        Out = \\sigma (W \\ast X + b)
+    Where:
+    * :math:`X`: Input value, a ``Tensor`` with NCHW format.
+    * :math:`W`: Filter value, a ``Tensor`` with shape [MCHW] .
+    * :math:`\\ast`: Convolution operation.
+    * :math:`b`: Bias value, a 2-D ``Tensor`` with shape [M, 1].
+    * :math:`\\sigma`: Activation function.
+    * :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
+
+    Example:
+        - Input:
+          Input shape: :math:`(N, C_{in}, H_{in}, W_{in})`
+          Filter shape: :math:`(C_{in}, C_{out}, H_f, W_f)`
+        - Output:
+          Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
+        Where
+        .. math::
+           H^\prime_{out} &= (H_{in} - 1) * strides[0] - 2 * paddings[0] + dilations[0] * (H_f - 1) + 1 \\\\
+           W^\prime_{out} &= (W_{in} - 1) * strides[1] - 2 * paddings[1] + dilations[1] * (W_f - 1) + 1 \\\\
+           H_{out} &\in [ H^\prime_{out}, H^\prime_{out} + strides[0] ) \\\\
+           W_{out} &\in [ W^\prime_{out}, W^\prime_{out} + strides[1] )
+
+    Parameters:
+        num_channels(int): The number of channels in the input image.
+        num_filters(int): The number of filters. It is the same as the number of channels
+            of the output feature map.
+        filter_size(int or tuple): The filter size. If filter_size is a tuple,
+            it must contain two integers, (filter_size_H, filter_size_W).
+            Otherwise, the filter will be a square.
+        output_size(int or tuple, optional): The output image size. If output size is a
+            tuple, it must contain two integers, (image_H, image_W). If None,
+            the output size is calculated from filter_size, padding, and stride.
+            If output_size and filter_size are specified at the same time, they
+            should follow the formula above. Default: None.
+        padding(int or tuple, optional): The padding size. If padding is a tuple, it must
+            contain two integers, (padding_H, padding_W). Otherwise, the
+            padding_H = padding_W = padding. Default: 0.
+        stride(int or tuple, optional): The stride size. If stride is a tuple, it must
+            contain two integers, (stride_H, stride_W). Otherwise, the
+            stride_H = stride_W = stride. Default: 1.
+        dilation(int or tuple, optional): The dilation size. If dilation is a tuple, it must
+            contain two integers, (dilation_H, dilation_W). Otherwise, the
+            dilation_H = dilation_W = dilation. Default: 1.
+        groups(int, optional): The groups number of the Conv2d transpose layer. Inspired by
+            grouped convolution in Alex Krizhevsky's Deep CNN paper, in which
+            when group=2, the first half of the filters is only connected to the
+            first half of the input channels, while the second half of the
+            filters is only connected to the second half of the input channels.
+            Default: 1.
+        param_attr (ParamAttr, optional): The parameter attribute for learnable weights(Parameter)
+            of conv2d_transpose. If it is set to None or one attribute of ParamAttr, conv2d_transpose
+            will create ParamAttr as param_attr. If the Initializer of the param_attr
+            is not set, the parameter is initialized with Xavier. Default: None.
+        bias_attr (ParamAttr or bool, optional): The attribute for the bias of conv2d_transpose.
+            If it is set to False, no bias will be added to the output units.
+            If it is set to None or one attribute of ParamAttr, conv2d_transpose
+            will create ParamAttr as bias_attr. If the Initializer of the bias_attr
+            is not set, the bias is initialized zero. Default: None.
+        use_cudnn(bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
+            library is installed. Default: True.
+        act (str, optional): Activation type, if it is set to None, activation is not appended.
+            Default: None.
+        dtype (str, optional): Data type, it can be "float32" or "float64". Default: "float32".
+
+    Attribute:
+        **weight** (Parameter): the learnable weights of filters of this layer.
+        **bias** (Parameter or None): the learnable bias of this layer.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+          import paddle.fluid as fluid
+          from paddleslim.core.layers import SuperConv2DTranspose
+          import numpy as np
+
+          with fluid.dygraph.guard():
+              data = np.random.random((3, 32, 32, 5)).astype('float32')
+              config = {'channel': 5}
+              super_convtranspose = SuperConv2DTranspose(num_channels=32, num_filters=10, filter_size=3)
+              ret = super_convtranspose(fluid.dygraph.base.to_variable(data), config)
+    """
     def __init__(self,
                  num_channels,
                  num_filters,
@@ -240,6 +394,48 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
 
 
 class SuperSeparableConv2D(fluid.dygraph.Layer):
+    """
+    This interface is used to construct a callable object of the ``SuperSeparableConv2D``
+    class.
+    The difference between ``SuperSeparableConv2D`` and ``SeparableConv2D`` is that
+    ``SuperSeparableConv2D`` must be fed a config dictionary in the format
+    {'channel': num_of_channel}, which specifies the number of output channels of the
+    first conv and the number of input channels of the second conv. It is used to slice
+    the first dimension of the weight and bias, so only the first `channel` channels of
+    the weight and bias are used and trained.
+
+    The architecture of the super separable convolution2D op is [Conv2D, norm layer (may be
+    BatchNorm or InstanceNorm), Conv2D]. The first conv is a depthwise conv: its number of
+    filters is the input channel number multiplied by scale_factor, and its group number
+    equals the number of input channels. The second conv is a standard conv whose filter
+    size and stride are both 1.
+
+    Parameters:
+        num_channels(int): The number of channels in the input image.
+        num_filters(int): The number of filters of the second conv. It is the same as the
+            number of channels of the output feature map.
+        filter_size(int or tuple): The first conv's filter size. If filter_size is a tuple,
+            it must contain two integers, (filter_size_H, filter_size_W).
+            Otherwise, the filter will be a square.
+        padding(int or tuple, optional): The first conv's padding size. If padding is a tuple,
+            it must contain two integers, (padding_H, padding_W). Otherwise, the
+            padding_H = padding_W = padding. Default: 0.
+        stride(int or tuple, optional): The first conv's stride size. If stride is a tuple,
+            it must contain two integers, (stride_H, stride_W). Otherwise, the
+            stride_H = stride_W = stride. Default: 1.
+        dilation(int or tuple, optional): The first conv's dilation size. If dilation is a tuple,
+            it must contain two integers, (dilation_H, dilation_W). Otherwise, the
+            dilation_H = dilation_W = dilation. Default: 1.
+        norm_layer(class): The normalization layer between the two convolutions. Default: InstanceNorm.
+        bias_attr (ParamAttr or bool, optional): The attribute for the bias of convolution.
+            If it is set to False, no bias will be added to the output units.
+            If it is set to None or one attribute of ParamAttr, convolution
+            will create ParamAttr as bias_attr. If the Initializer of the bias_attr
+            is not set, the bias is initialized zero. Default: None.
+        scale_factor(float): The scale factor of the first conv's output channel. Default: 1.
+        use_cudnn(bool, optional): Use cudnn kernel or not, it is valid only when the cudnn
+            library is installed. Default: True.
+
+    Returns:
+        None
+    """
     def __init__(self,
                  num_channels,
                  num_filters,
@@ -263,21 +459,9 @@ class SuperSeparableConv2D(fluid.dygraph.Layer):
                 groups=num_channels,
                 bias_attr=bias_attr)
         ])
-        if norm_layer == InstanceNorm:
-            self.conv.extend([
-                SuperInstanceNorm(
-                    num_channels * scale_factor,
-                    param_attr=fluid.ParamAttr(
-                        initializer=fluid.initializer.Constant(1.0),
-                        learning_rate=0.0,
-                        trainable=False),
-                    bias_attr=fluid.ParamAttr(
-                        initializer=fluid.initializer.Constant(0.0),
-                        learning_rate=0.0,
-                        trainable=False))
-            ])
-        else:
-            raise NotImplementedError
+
+        self.conv.extend([norm_layer(num_channels * scale_factor)])
+
         self.conv.extend([
             Conv2D(
                 num_channels=num_channels * scale_factor,
diff --git a/paddleslim/models/dygraph/modules.py b/paddleslim/models/dygraph/modules.py
index 41d80e44..b9c4c7fd 100644
--- a/paddleslim/models/dygraph/modules.py
+++ b/paddleslim/models/dygraph/modules.py
@@ -31,7 +31,7 @@ class SeparableConv2D(fluid.dygraph.Layer):
                  use_bias=True,
                  scale_factor=1,
                  stddev=0.02,
-                 use_cudnn=use_cudnn):
+                 use_cudnn=False):
         super(SeparableConv2D, self).__init__()
 
         self.conv = fluid.dygraph.LayerList([
@@ -41,7 +41,7 @@ class SeparableConv2D(fluid.dygraph.Layer):
                 filter_size=filter_size,
                 stride=stride,
                 padding=padding,
-                use_cudnn=False,
+                use_cudnn=use_cudnn,
                 groups=num_channels,
                 param_attr=fluid.ParamAttr(
                     initializer=fluid.initializer.NormalInitializer(
-- 
GitLab
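
Usage note: a minimal sketch of how the three Super layers added in this patch are called,
following the docstring examples above. The input shapes and the 'channel' values below are
illustrative assumptions; the only constraint stated in the patch is that 'channel' must not
exceed the num_filters given at construction time. The (input, config) call convention for
SuperSeparableConv2D is assumed to match the other two layers.

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.base import to_variable
    from paddleslim.core.layers import (SuperConv2D, SuperConv2DTranspose,
                                        SuperSeparableConv2D)

    with fluid.dygraph.guard():
        # NCHW input with 8 channels (illustrative shape).
        x = to_variable(np.random.uniform(-1, 1, [4, 8, 32, 32]).astype('float32'))

        # Each layer is constructed with its maximum number of filters; the config
        # dict passed at call time selects how many of the first output channels
        # (and the corresponding slice of weight and bias) are actually used.
        super_conv = SuperConv2D(8, 16, 3)
        out = super_conv(x, {'channel': 12})    # uses only the first 12 of 16 filters

        super_deconv = SuperConv2DTranspose(
            num_channels=8, num_filters=16, filter_size=3)
        out = super_deconv(x, {'channel': 12})  # same slicing rule for the transpose conv

        super_sep = SuperSeparableConv2D(
            num_channels=8, num_filters=8, filter_size=3)
        out = super_sep(x, {'channel': 8})      # depthwise conv -> norm_layer -> 1x1 conv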