diff --git a/paddleslim/core/__init__.py b/paddleslim/core/__init__.py index 1f9fab5c3e20facf2e18e3946513ee05df742e41..a9729102d71c9bebc2ac379a7b7598a32b33c08e 100644 --- a/paddleslim/core/__init__.py +++ b/paddleslim/core/__init__.py @@ -14,5 +14,6 @@ from .graph_wrapper import GraphWrapper, VarWrapper, OpWrapper from .registry import Registry +from .layers import SuperInstanceNorm, SuperConv2D, SuperConv2DTranspose, SuperSeparableConv2D -__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper', 'Registry'] +__all__ = ['GraphWrapper', 'VarWrapper', 'OpWrapper', 'Registry', 'SuperInstanceNorm', 'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D'] diff --git a/paddleslim/core/layers.py b/paddleslim/core/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d676e4e03b334deb38af12ad36dd7b84f7338720 --- /dev/null +++ b/paddleslim/core/layers.py @@ -0,0 +1,356 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import paddle.fluid.dygraph_utils as dygraph_utils +from paddle.fluid.data_feeder import check_variable_and_dtype, check_type +from paddle.fluid.dygraph.base import to_variable +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.dygraph.nn import InstanceNorm, Conv2D, Conv2DTranspose +import paddle.fluid.core as core +import numpy as np + +__all__ = ['SuperInstanceNorm', 'SuperConv2D', 'SuperConv2DTranspose', 'SuperSeparableConv2D'] + + +### NOTE: this op can delete after this pr merged: https://github.com/PaddlePaddle/Paddle/pull/24717 +class SuperInstanceNorm(fluid.dygraph.InstanceNorm): + def __init__(self, + num_channels, + epsilon=1e-5, + param_attr=None, + bias_attr=None, + dtype='float32'): + super(SuperInstanceNorm, self).__init__( + num_channels, + epsilon=1e-5, + param_attr=None, + bias_attr=None, + dtype='float32') + + def forward(self, input): + in_nc = int(input.shape[1]) + scale = self.scale[:in_nc] + bias = self.scale[:in_nc] + if in_dygraph_mode(): + out, _, _ = core.ops.instance_norm(input, scale, bias, 'epsilon', + self._epsilon) + return out + check_variable_and_dtype(input, 'input', ['float32', 'float64'], + "SuperInstanceNorm") + + attrs = {"epsilon": self._epsilon} + + inputs = {"X": [input], "Scale": [scale], "Bias": [bias]} + + saved_mean = self._helper.create_variable_for_type_inference( + dtype=self._dtype, stop_gradient=True) + saved_variance = self._helper.create_variable_for_type_inference( + dtype=self._dtype, stop_gradient=True) + instance_norm_out = self._helper.create_variable_for_type_inference( + self._dtype) + + outputs = { + "Y": [instance_norm_out], + "SavedMean": [saved_mean], + "SavedVariance": [saved_variance] + } + + self._helper.append_op( + type="instance_norm", inputs=inputs, outputs=outputs, attrs=attrs) + return instance_norm_out + + +class SuperConv2D(fluid.dygraph.Conv2D): + def __init__(self, + num_channels, + num_filters, + filter_size, + stride=1, + padding=0, + dilation=1, + groups=None, + param_attr=None, + bias_attr=None, + use_cudnn=True, + act=None, + dtype='float32'): + super(SuperConv2D, self).__init__( + num_channels, num_filters, filter_size, stride, padding, dilation, + groups, param_attr, bias_attr, use_cudnn, act, dtype) + + def forward(self, input, config): + in_nc = int(input.shape[1]) + out_nc = config['channel'] + weight = self.weight[:out_nc, :in_nc, :, :] + #print('super conv shape', weight.shape) + if in_dygraph_mode(): + if self._l_type == 'conv2d': + attrs = ('strides', self._stride, 'paddings', self._padding, + 'dilations', self._dilation, 'groups', self._groups + if self._groups else 1, 'use_cudnn', self._use_cudnn) + out = core.ops.conv2d(input, weight, *attrs) + elif self._l_type == 'depthwise_conv2d': + attrs = ('strides', self._stride, 'paddings', self._padding, + 'dilations', self._dilation, 'groups', self._groups, + 'use_cudnn', self._use_cudnn) + out = core.ops.depthwise_conv2d(input, weight, *attrs) + else: + raise ValueError("conv type error") + + pre_bias = out + if self.bias is not None: + bias = self.bias[:out_nc] + pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, + 1) + else: + pre_act = pre_bias + + return dygraph_utils._append_activation_in_dygraph(pre_act, + self._act) + + inputs = {'Input': [input], 'Filter': [weight]} + attrs = { + 'strides': self._stride, + 'paddings': self._padding, + 'dilations': self._dilation, + 'groups': self._groups if self._groups else 1, + 'use_cudnn': self._use_cudnn, + 'use_mkldnn': False, + } + check_variable_and_dtype( + input, 'input', ['float16', 'float32', 'float64'], 'SuperConv2D') + pre_bias = self._helper.create_variable_for_type_inference( + dtype=self._dtype) + + self._helper.append_op( + type=self._l_type, + inputs={ + 'Input': input, + 'Filter': weight, + }, + outputs={"Output": pre_bias}, + attrs=attrs) + + if self.bias is not None: + bias = self.bias[:out_nc] + pre_act = self._helper.create_variable_for_type_inference( + dtype=self._dtype) + self._helper.append_op( + type='elementwise_add', + inputs={'X': [pre_bias], + 'Y': [bias]}, + outputs={'Out': [pre_act]}, + attrs={'axis': 1}) + else: + pre_act = pre_bias + + # Currently, we don't support inplace in dygraph mode + return self._helper.append_activation(pre_act, act=self._act) + + +class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose): + def __init__(self, + num_channels, + num_filters, + filter_size, + output_size=None, + padding=0, + stride=1, + dilation=1, + groups=None, + param_attr=None, + bias_attr=None, + use_cudnn=True, + act=None, + dtype='float32'): + super(SuperConv2DTranspose, + self).__init__(num_channels, num_filters, filter_size, + output_size, padding, stride, dilation, groups, + param_attr, bias_attr, use_cudnn, act, dtype) + + def forward(self, input, config): + in_nc = int(input.shape[1]) + out_nc = int(config['channel']) + weight = self.weight[:in_nc, :out_nc, :, :] + if in_dygraph_mode(): + op = getattr(core.ops, self._op_type) + out = op(input, weight, 'output_size', self._output_size, + 'strides', self._stride, 'paddings', self._padding, + 'dilations', self._dilation, 'groups', self._groups, + 'use_cudnn', self._use_cudnn) + pre_bias = out + if self.bias is not None: + bias = self.bias[:out_nc] + pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, + 1) + else: + pre_act = pre_bias + + return dygraph_utils._append_activation_in_dygraph( + pre_act, act=self._act) + + check_variable_and_dtype(input, 'input', + ['float16', 'float32', 'float64'], + "SuperConv2DTranspose") + + inputs = {'Input': [input], 'Filter': [weight]} + attrs = { + 'output_size': self._output_size, + 'strides': self._stride, + 'paddings': self._padding, + 'dilations': self._dilation, + 'groups': self._groups, + 'use_cudnn': self._use_cudnn + } + + pre_bias = self._helper.create_variable_for_type_inference( + dtype=input.dtype) + self._helper.append_op( + type=self._op_type, + inputs=inputs, + outputs={'Output': pre_bias}, + attrs=attrs) + + if self.bias is not None: + pre_act = self._helper.create_variable_for_type_inference( + dtype=self._dtype) + self._helper.append_op( + type='elementwise_add', + inputs={'X': [pre_bias], + 'Y': [bias]}, + outputs={'Out': [pre_act]}, + attrs={'axis': 1}) + else: + pre_act = pre_bias + + out = self._helper.append_activation(pre_act, act=self._act) + return out + + +class SuperSeparableConv2D(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters, + filter_size, + stride=1, + padding=0, + dilation=1, + norm_layer=InstanceNorm, + bias_attr=None, + scale_factor=1, + use_cudnn=False): + super(SuperSeparableConv2D, self).__init__() + self.conv = fluid.dygraph.LayerList([ + fluid.dygraph.nn.Conv2D( + num_channels=num_channels, + num_filters=num_channels * scale_factor, + filter_size=filter_size, + stride=stride, + padding=padding, + use_cudnn=False, + groups=num_channels, + bias_attr=bias_attr) + ]) + if norm_layer == InstanceNorm: + self.conv.extend([ + SuperInstanceNorm( + num_channels * scale_factor, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(1.0), + learning_rate=0.0, + trainable=False), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.0), + learning_rate=0.0, + trainable=False)) + ]) + else: + raise NotImplementedError + self.conv.extend([ + Conv2D( + num_channels=num_channels * scale_factor, + num_filters=num_filters, + filter_size=1, + stride=1, + use_cudnn=use_cudnn, + bias_attr=bias_attr) + ]) + + def forward(self, input, config): + in_nc = int(input.shape[1]) + out_nc = int(config['channel']) + weight = self.conv[0].weight[:in_nc] + ### conv1 + if in_dygraph_mode(): + if self.conv[0]._l_type == 'conv2d': + attrs = ('strides', self.conv[0]._stride, 'paddings', + self.conv[0]._padding, 'dilations', + self.conv[0]._dilation, 'groups', in_nc, 'use_cudnn', + self.conv[0]._use_cudnn) + out = core.ops.conv2d(input, weight, *attrs) + elif self.conv[0]._l_type == 'depthwise_conv2d': + attrs = ('strides', self.conv[0]._stride, 'paddings', + self.conv[0]._padding, 'dilations', + self.conv[0]._dilation, 'groups', in_nc, 'use_cudnn', + self.conv[0]._use_cudnn) + out = core.ops.depthwise_conv2d(input, weight, *attrs) + else: + raise ValueError("conv type error") + + pre_bias = out + if self.conv[0].bias is not None: + bias = self.conv[0].bias[:in_nc] + pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, + 1) + else: + pre_act = pre_bias + + conv0_out = dygraph_utils._append_activation_in_dygraph( + pre_act, self.conv[0]._act) + + norm_out = self.conv[1](conv0_out) + + weight = self.conv[2].weight[:out_nc, :in_nc, :, :] + + if in_dygraph_mode(): + if self.conv[2]._l_type == 'conv2d': + attrs = ('strides', self.conv[2]._stride, 'paddings', + self.conv[2]._padding, 'dilations', + self.conv[2]._dilation, 'groups', self.conv[2]._groups + if self.conv[2]._groups else 1, 'use_cudnn', + self.conv[2]._use_cudnn) + out = core.ops.conv2d(norm_out, weight, *attrs) + elif self.conv[2]._l_type == 'depthwise_conv2d': + attrs = ('strides', self.conv[2]._stride, 'paddings', + self.conv[2]._padding, 'dilations', + self.conv[2]._dilation, 'groups', + self.conv[2]._groups, 'use_cudnn', + self.conv[2]._use_cudnn) + out = core.ops.depthwise_conv2d(norm_out, weight, *attrs) + else: + raise ValueError("conv type error") + + pre_bias = out + if self.conv[2].bias is not None: + bias = self.conv[2].bias[:out_nc] + pre_act = dygraph_utils._append_bias_in_dygraph(pre_bias, bias, + 1) + else: + pre_act = pre_bias + + conv1_out = dygraph_utils._append_activation_in_dygraph( + pre_act, self.conv[2]._act) + return conv1_out + diff --git a/paddleslim/models/__init__.py b/paddleslim/models/__init__.py index 14ea9f3d15fa953f0c4dba47aee6bc45a0e1ee62..e37afd9882ca4cbbaa9a1ad5a07be83069788348 100644 --- a/paddleslim/models/__init__.py +++ b/paddleslim/models/__init__.py @@ -13,7 +13,9 @@ # limitations under the License. from __future__ import absolute_import +from .dygraph import modules from .util import image_classification from .slimfacenet import SlimFaceNet_A_x0_60, SlimFaceNet_B_x0_75, SlimFaceNet_C_x0_75 from .slim_mobilenet import SlimMobileNet_v1, SlimMobileNet_v2, SlimMobileNet_v3, SlimMobileNet_v4, SlimMobileNet_v5 __all__ = ["image_classification"] +__all__ += modules.__all__ diff --git a/paddleslim/models/dygraph/__init__.py b/paddleslim/models/dygraph/__init__.py index d618ee708b3c9d594dd5e6b02d9ee75504452c38..a7d6d47b87cd98f72090366533e4124b5f30f6d3 100644 --- a/paddleslim/models/dygraph/__init__.py +++ b/paddleslim/models/dygraph/__init__.py @@ -15,5 +15,6 @@ from __future__ import absolute_import from .mobilenet import MobileNetV1 from .resnet import ResNet +from .modules import SeparableConv2D, MobileResnetBlock, ResnetBlock -__all__ = ["MobileNetV1", "ResNet"] +__all__ = ["MobileNetV1", "ResNet", "SeparableConv2", "MobileResnetBlock", "ResnetBlock"] diff --git a/paddleslim/models/dygraph/modules.py b/paddleslim/models/dygraph/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..41d80e445539214b6a3a52cd3bb885bd68b062be --- /dev/null +++ b/paddleslim/models/dygraph/modules.py @@ -0,0 +1,192 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +import paddle.fluid as fluid +from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, BatchNorm, InstanceNorm, Dropout +from paddle.nn.layer import ReLU, Pad2D + +__all__ = ['SeparableConv2D', 'MobileResnetBlock', 'ResnetBlock'] + + +class SeparableConv2D(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters, + filter_size, + stride=1, + padding=0, + norm_layer=InstanceNorm, + use_bias=True, + scale_factor=1, + stddev=0.02, + use_cudnn=use_cudnn): + super(SeparableConv2D, self).__init__() + + self.conv = fluid.dygraph.LayerList([ + Conv2D( + num_channels=num_channels, + num_filters=num_channels * scale_factor, + filter_size=filter_size, + stride=stride, + padding=padding, + use_cudnn=False, + groups=num_channels, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=stddev)), + bias_attr=use_bias) + ]) + + self.conv.extend([norm_layer(num_channels * scale_factor)]) + + self.conv.extend([ + Conv2D( + num_channels=num_channels * scale_factor, + num_filters=num_filters, + filter_size=1, + stride=1, + use_cudnn=use_cudnn, + param_attr=fluid.ParamAttr( + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=stddev)), + bias_attr=use_bias) + ]) + + def forward(self, inputs): + for sublayer in self.conv: + inputs = sublayer(inputs) + return inputs + + +class MobileResnetBlock(fluid.dygraph.Layer): + def __init__(self, in_c, out_c, padding_type, norm_layer, dropout_rate, + use_bias): + super(MobileResnetBlock, self).__init__() + self.padding_type = padding_type + self.dropout_rate = dropout_rate + self.conv_block = fluid.dygraph.LayerList([]) + + p = 0 + if self.padding_type == 'reflect': + self.conv_block.extend( + [Pad2D( + paddings=[1, 1, 1, 1], mode='reflect')]) + elif self.padding_type == 'replicate': + self.conv_block.extend( + [Pad2D( + inputs, paddings=[1, 1, 1, 1], mode='edge')]) + elif self.padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % + self.padding_type) + + self.conv_block.extend([ + SeparableConv2D( + num_channels=in_c, + num_filters=out_c, + filter_size=3, + padding=p, + stride=1), norm_layer(out_c), ReLU() + ]) + + self.conv_block.extend([Dropout(p=self.dropout_rate)]) + + if self.padding_type == 'reflect': + self.conv_block.extend( + [Pad2D( + paddings=[1, 1, 1, 1], mode='reflect')]) + elif self.padding_type == 'replicate': + self.conv_block.extend( + [Pad2D( + inputs, paddings=[1, 1, 1, 1], mode='edge')]) + elif self.padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % + self.padding_type) + + self.conv_block.extend([ + SeparableConv2D( + num_channels=out_c, + num_filters=in_c, + filter_size=3, + padding=p, + stride=1), norm_layer(in_c) + ]) + + def forward(self, inputs): + y = inputs + for sublayer in self.conv_block: + y = sublayer(y) + out = inputs + y + return out + + +class ResnetBlock(fluid.dygraph.Layer): + def __init__(self, + dim, + padding_type, + norm_layer, + dropout_rate, + use_bias=False): + super(ResnetBlock, self).__init__() + + self.conv_block = fluid.dygraph.LayerList([]) + p = 0 + if padding_type == 'reflect': + self.conv_block.extend( + [Pad2D( + paddings=[1, 1, 1, 1], mode='reflect')]) + elif padding_type == 'replicate': + self.conv_block.extend([Pad2D(paddings=[1, 1, 1, 1], mode='edge')]) + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % + padding_type) + + self.conv_block.extend([ + Conv2D( + dim, dim, filter_size=3, padding=p, bias_attr=use_bias), + norm_layer(dim), ReLU() + ]) + self.conv_block.extend([Dropout(dropout_rate)]) + + p = 0 + if padding_type == 'reflect': + self.conv_block.extend( + [Pad2D( + paddings=[1, 1, 1, 1], mode='reflect')]) + elif padding_type == 'replicate': + self.conv_block.extend([Pad2D(paddings=[1, 1, 1, 1], mode='edge')]) + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % + padding_type) + + self.conv_block.extend([ + Conv2D( + dim, dim, filter_size=3, padding=p, bias_attr=use_bias), + norm_layer(dim) + ]) + + def forward(self, inputs): + y = inputs + for sublayer in self.conv_block: + y = sublayer(y) + return y + inputs +