From f3531c7baae81c1b868ee0ad3df320b1b10378f6 Mon Sep 17 00:00:00 2001 From: huzhiqiang <912790387@qq.com> Date: Mon, 18 Apr 2022 14:10:41 +0800 Subject: [PATCH] [infrt] add efficientnet model (#41507) --- .../tests/models/efficientnet-b4/model.py | 26 ++ .../models/efficientnet-b4/net/__init__.py | 15 + .../efficientnet-b4/net/efficientnet.py | 284 +++++++++++++ .../tests/models/efficientnet-b4/net/utils.py | 385 ++++++++++++++++++ paddle/scripts/infrt_build.sh | 13 +- 5 files changed, 718 insertions(+), 5 deletions(-) create mode 100644 paddle/infrt/tests/models/efficientnet-b4/model.py create mode 100644 paddle/infrt/tests/models/efficientnet-b4/net/__init__.py create mode 100644 paddle/infrt/tests/models/efficientnet-b4/net/efficientnet.py create mode 100644 paddle/infrt/tests/models/efficientnet-b4/net/utils.py diff --git a/paddle/infrt/tests/models/efficientnet-b4/model.py b/paddle/infrt/tests/models/efficientnet-b4/model.py new file mode 100644 index 00000000000..c660c3a4674 --- /dev/null +++ b/paddle/infrt/tests/models/efficientnet-b4/model.py @@ -0,0 +1,26 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# url: https://aistudio.baidu.com/aistudio/projectdetail/3756986?forkThirdPart=1 +from net import EfficientNet +from paddle.jit import to_static +from paddle.static import InputSpec +import paddle +import sys + +model = EfficientNet.from_name('efficientnet-b4') +net = to_static( + model, input_spec=[InputSpec( + shape=[None, 3, 256, 256], name='x')]) +paddle.jit.save(net, sys.argv[1]) diff --git a/paddle/infrt/tests/models/efficientnet-b4/net/__init__.py b/paddle/infrt/tests/models/efficientnet-b4/net/__init__.py new file mode 100644 index 00000000000..d4e557829ae --- /dev/null +++ b/paddle/infrt/tests/models/efficientnet-b4/net/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .efficientnet import EfficientNet diff --git a/paddle/infrt/tests/models/efficientnet-b4/net/efficientnet.py b/paddle/infrt/tests/models/efficientnet-b4/net/efficientnet.py new file mode 100644 index 00000000000..a9956fcdc88 --- /dev/null +++ b/paddle/infrt/tests/models/efficientnet-b4/net/efficientnet.py @@ -0,0 +1,284 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from .utils import (round_filters, round_repeats, drop_connect, + get_same_padding_conv2d, get_model_params, + efficientnet_params, load_pretrained_weights) + + +class MBConvBlock(nn.Layer): + """ + Mobile Inverted Residual Bottleneck Block + + Args: + block_args (namedtuple): BlockArgs, see above + global_params (namedtuple): GlobalParam, see above + + Attributes: + has_se (bool): Whether the block contains a Squeeze and Excitation layer. + """ + + def __init__(self, block_args, global_params): + super().__init__() + self._block_args = block_args + self._bn_mom = global_params.batch_norm_momentum + self._bn_eps = global_params.batch_norm_epsilon + self.has_se = (self._block_args.se_ratio is not None) and ( + 0 < self._block_args.se_ratio <= 1) + self.id_skip = block_args.id_skip # skip connection and drop connect + + # Get static or dynamic convolution depending on image size + Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) + + # Expansion phase + inp = self._block_args.input_filters # number of input channels + oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels + if self._block_args.expand_ratio != 1: + self._expand_conv = Conv2d( + in_channels=inp, + out_channels=oup, + kernel_size=1, + bias_attr=False) + self._bn0 = nn.BatchNorm2D( + num_features=oup, momentum=self._bn_mom, epsilon=self._bn_eps) + + # Depthwise convolution phase + k = self._block_args.kernel_size + s = self._block_args.stride + self._depthwise_conv = Conv2d( + in_channels=oup, + out_channels=oup, + groups=oup, # groups makes it depthwise + kernel_size=k, + stride=s, + bias_attr=False) + self._bn1 = nn.BatchNorm2D( + num_features=oup, momentum=self._bn_mom, epsilon=self._bn_eps) + + # Squeeze and Excitation layer, if desired + if self.has_se: + num_squeezed_channels = max(1, + int(self._block_args.input_filters * + self._block_args.se_ratio)) + self._se_reduce = Conv2d( + in_channels=oup, + out_channels=num_squeezed_channels, + kernel_size=1) + self._se_expand = Conv2d( + in_channels=num_squeezed_channels, + out_channels=oup, + kernel_size=1) + + # Output phase + final_oup = self._block_args.output_filters + self._project_conv = Conv2d( + in_channels=oup, + out_channels=final_oup, + kernel_size=1, + bias_attr=False) + self._bn2 = nn.BatchNorm2D( + num_features=final_oup, momentum=self._bn_mom, epsilon=self._bn_eps) + self._swish = nn.Hardswish() + + def forward(self, inputs, drop_connect_rate=None): + """ + :param inputs: input tensor + :param drop_connect_rate: drop connect rate (float, between 0 and 1) + :return: output of block + """ + + # Expansion and Depthwise Convolution + x = inputs + if self._block_args.expand_ratio != 1: + x = self._swish(self._bn0(self._expand_conv(inputs))) + x = self._swish(self._bn1(self._depthwise_conv(x))) + + # Squeeze and Excitation + if self.has_se: + x_squeezed = F.adaptive_avg_pool2d(x, 1) + x_squeezed = self._se_expand( + self._swish(self._se_reduce(x_squeezed))) + x = F.sigmoid(x_squeezed) * x + + x = self._bn2(self._project_conv(x)) + + # Skip connection and drop connect + input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters + if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: + if drop_connect_rate: + x = drop_connect( + x, prob=drop_connect_rate, training=self.training) + x = x + inputs # skip connection + return x + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export)""" + self._swish = nn.Hardswish() if memory_efficient else nn.Swish() + + +class EfficientNet(nn.Layer): + """ + An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods + + Args: + blocks_args (list): A list of BlockArgs to construct blocks + global_params (namedtuple): A set of GlobalParams shared between blocks + + Example: + model = EfficientNet.from_pretrained('efficientnet-b0') + + """ + + def __init__(self, blocks_args=None, global_params=None): + super().__init__() + assert isinstance(blocks_args, list), 'blocks_args should be a list' + assert len(blocks_args) > 0, 'block args must be greater than 0' + self._global_params = global_params + self._blocks_args = blocks_args + + # Get static or dynamic convolution depending on image size + Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) + + # Batch norm parameters + bn_mom = self._global_params.batch_norm_momentum + bn_eps = self._global_params.batch_norm_epsilon + + # Stem + in_channels = 3 # rgb + out_channels = round_filters( + 32, self._global_params) # number of output channels + self._conv_stem = Conv2d( + in_channels, out_channels, kernel_size=3, stride=2, bias_attr=False) + self._bn0 = nn.BatchNorm2D( + num_features=out_channels, momentum=bn_mom, epsilon=bn_eps) + + # Build blocks + self._blocks = nn.LayerList([]) + for block_args in self._blocks_args: + + # Update block input and output filters based on depth multiplier. + block_args = block_args._replace( + input_filters=round_filters(block_args.input_filters, + self._global_params), + output_filters=round_filters(block_args.output_filters, + self._global_params), + num_repeat=round_repeats(block_args.num_repeat, + self._global_params)) + + # The first block needs to take care of stride and filter size increase. + self._blocks.append(MBConvBlock(block_args, self._global_params)) + if block_args.num_repeat > 1: + block_args = block_args._replace( + input_filters=block_args.output_filters, stride=1) + for _ in range(block_args.num_repeat - 1): + self._blocks.append( + MBConvBlock(block_args, self._global_params)) + + # Head + in_channels = block_args.output_filters # output of final block + out_channels = round_filters(1280, self._global_params) + self._conv_head = Conv2d( + in_channels, out_channels, kernel_size=1, bias_attr=False) + self._bn1 = nn.BatchNorm2D( + num_features=out_channels, momentum=bn_mom, epsilon=bn_eps) + + # Final linear layer + self._avg_pooling = nn.AdaptiveAvgPool2D(1) + self._dropout = nn.Dropout(self._global_params.dropout_rate) + self._fc = nn.Linear(out_channels, self._global_params.num_classes) + self._swish = nn.Hardswish() + + def set_swish(self, memory_efficient=True): + """Sets swish function as memory efficient (for training) or standard (for export)""" + self._swish = nn.Hardswish() if memory_efficient else nn.Swish() + for block in self._blocks: + block.set_swish(memory_efficient) + + def extract_features(self, inputs): + """ Returns output of the final convolution layer """ + + # Stem + x = self._swish(self._bn0(self._conv_stem(inputs))) + + # Blocks + for idx, block in enumerate(self._blocks): + drop_connect_rate = self._global_params.drop_connect_rate + if drop_connect_rate: + drop_connect_rate *= float(idx) / len(self._blocks) + x = block(x, drop_connect_rate=drop_connect_rate) + + # Head + x = self._swish(self._bn1(self._conv_head(x))) + + return x + + def forward(self, inputs): + """ Calls extract_features to extract features, applies final linear layer, and returns logits. """ + bs = inputs.shape[0] + # Convolution layers + x = self.extract_features(inputs) + + # Pooling and final linear layer + x = self._avg_pooling(x) + x = paddle.reshape(x, (bs, -1)) + x = self._dropout(x) + x = self._fc(x) + return x + + @classmethod + def from_name(cls, model_name, override_params=None): + cls._check_model_name_is_valid(model_name) + blocks_args, global_params = get_model_params(model_name, + override_params) + return cls(blocks_args, global_params) + + @classmethod + def from_pretrained(cls, + model_name, + advprop=False, + num_classes=1000, + in_channels=3): + model = cls.from_name( + model_name, override_params={'num_classes': num_classes}) + load_pretrained_weights( + model, model_name, load_fc=(num_classes == 1000), advprop=advprop) + if in_channels != 3: + Conv2d = get_same_padding_conv2d( + image_size=model._global_params.image_size) + out_channels = round_filters(32, model._global_params) + model._conv_stem = Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + bias_attr=False) + return model + + @classmethod + def get_image_size(cls, model_name): + cls._check_model_name_is_valid(model_name) + _, _, res, _ = efficientnet_params(model_name) + return res + + @classmethod + def _check_model_name_is_valid(cls, model_name): + """ Validates model name. """ + valid_models = ['efficientnet-b' + str(i) for i in range(9)] + if model_name not in valid_models: + raise ValueError('model_name should be one of: ' + ', '.join( + valid_models)) diff --git a/paddle/infrt/tests/models/efficientnet-b4/net/utils.py b/paddle/infrt/tests/models/efficientnet-b4/net/utils.py new file mode 100644 index 00000000000..3bf8b4eb730 --- /dev/null +++ b/paddle/infrt/tests/models/efficientnet-b4/net/utils.py @@ -0,0 +1,385 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +import math +from functools import partial +import collections + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +# Parameters for the entire model (stem, all blocks, and head) +GlobalParams = collections.namedtuple('GlobalParams', [ + 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'num_classes', + 'width_coefficient', 'depth_coefficient', 'depth_divisor', 'min_depth', + 'drop_connect_rate', 'image_size' +]) + +# Parameters for an individual model block +BlockArgs = collections.namedtuple('BlockArgs', [ + 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', + 'expand_ratio', 'id_skip', 'stride', 'se_ratio' +]) + +# Change namedtuple defaults +GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields) +BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields) + + +def round_filters(filters, global_params): + """ Calculate and round number of filters based on depth multiplier. """ + multiplier = global_params.width_coefficient + if not multiplier: + return filters + divisor = global_params.depth_divisor + min_depth = global_params.min_depth + filters *= multiplier + min_depth = min_depth or divisor + new_filters = max(min_depth, + int(filters + divisor / 2) // divisor * divisor) + if new_filters < 0.9 * filters: # prevent rounding by more than 10% + new_filters += divisor + return int(new_filters) + + +def round_repeats(repeats, global_params): + """ Round number of filters based on depth multiplier. """ + multiplier = global_params.depth_coefficient + if not multiplier: + return repeats + return int(math.ceil(multiplier * repeats)) + + +def drop_connect(inputs, prob, training): + """Drop input connection""" + if not training: + return inputs + keep_prob = 1.0 - prob + inputs_shape = paddle.shape(inputs) + random_tensor = keep_prob + paddle.rand(shape=[inputs_shape[0], 1, 1, 1]) + binary_tensor = paddle.floor(random_tensor) + output = inputs / keep_prob * binary_tensor + return output + + +def get_same_padding_conv2d(image_size=None): + """ Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. """ + if image_size is None: + return Conv2dDynamicSamePadding + else: + return partial(Conv2dStaticSamePadding, image_size=image_size) + + +class Conv2dDynamicSamePadding(nn.Conv2D): + """ 2D Convolutions like TensorFlow, for a dynamic image size """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias_attr=None): + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + 0, + dilation, + groups, + bias_attr=bias_attr) + self.stride = self._stride if len( + self._stride) == 2 else [self._stride[0]] * 2 + + def forward(self, x): + ih, iw = x.shape[-2:] + kh, kw = self.weight.shape[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + + (kh - 1) * self._dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + + (kw - 1) * self._dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + x = F.pad(x, [ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 + ]) + return F.conv2d(x, self.weight, self.bias, self.stride, self._padding, + self._dilation, self._groups) + + +class Conv2dStaticSamePadding(nn.Conv2D): + """ 2D Convolutions like TensorFlow, for a fixed image size""" + + def __init__(self, + in_channels, + out_channels, + kernel_size, + image_size=None, + **kwargs): + if 'stride' in kwargs and isinstance(kwargs['stride'], list): + kwargs['stride'] = kwargs['stride'][0] + super().__init__(in_channels, out_channels, kernel_size, **kwargs) + self.stride = self._stride if len( + self._stride) == 2 else [self._stride[0]] * 2 + + # Calculate padding based on image size and save it + assert image_size is not None + ih, iw = image_size if type( + image_size) == list else [image_size, image_size] + kh, kw = self.weight.shape[-2:] + sh, sw = self.stride + oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) + pad_h = max((oh - 1) * self.stride[0] + + (kh - 1) * self._dilation[0] + 1 - ih, 0) + pad_w = max((ow - 1) * self.stride[1] + + (kw - 1) * self._dilation[1] + 1 - iw, 0) + if pad_h > 0 or pad_w > 0: + self.static_padding = nn.Pad2D([ + pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2 + ]) + else: + self.static_padding = Identity() + + def forward(self, x): + x = self.static_padding(x) + x = F.conv2d(x, self.weight, self.bias, self.stride, self._padding, + self._dilation, self._groups) + return x + + +class Identity(nn.Layer): + def __init__(self, ): + super().__init__() + + def forward(self, x): + return x + + +def efficientnet_params(model_name): + """ Map EfficientNet model name to parameter coefficients. """ + params_dict = { + # Coefficients: width,depth,resolution,dropout + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + 'efficientnet-l2': (4.3, 5.3, 800, 0.5), + } + return params_dict[model_name] + + +class BlockDecoder(object): + """ Block Decoder for readability, straight from the official TensorFlow repository """ + + @staticmethod + def _decode_block_string(block_string): + """ Gets a block through a string notation of arguments. """ + assert isinstance(block_string, str) + + ops = block_string.split('_') + options = {} + for op in ops: + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # Check stride + assert (('s' in options and len(options['s']) == 1) or + (len(options['s']) == 2 and options['s'][0] == options['s'][1])) + + return BlockArgs( + kernel_size=int(options['k']), + num_repeat=int(options['r']), + input_filters=int(options['i']), + output_filters=int(options['o']), + expand_ratio=int(options['e']), + id_skip=('noskip' not in block_string), + se_ratio=float(options['se']) if 'se' in options else None, + stride=[int(options['s'][0])]) + + @staticmethod + def _encode_block_string(block): + """Encodes a block to a string.""" + args = [ + 'r%d' % block.num_repeat, 'k%d' % block.kernel_size, 's%d%d' % + (block.strides[0], block.strides[1]), 'e%s' % block.expand_ratio, + 'i%d' % block.input_filters, 'o%d' % block.output_filters + ] + if 0 < block.se_ratio <= 1: + args.append('se%s' % block.se_ratio) + if block.id_skip is False: + args.append('noskip') + return '_'.join(args) + + @staticmethod + def decode(string_list): + """ + Decodes a list of string notations to specify blocks inside the network. + + :param string_list: a list of strings, each string is a notation of block + :return: a list of BlockArgs namedtuples of block args + """ + assert isinstance(string_list, list) + blocks_args = [] + for block_string in string_list: + blocks_args.append(BlockDecoder._decode_block_string(block_string)) + return blocks_args + + @staticmethod + def encode(blocks_args): + """ + Encodes a list of BlockArgs to a list of strings. + + :param blocks_args: a list of BlockArgs namedtuples of block args + :return: a list of strings, each string is a notation of block + """ + block_strings = [] + for block in blocks_args: + block_strings.append(BlockDecoder._encode_block_string(block)) + return block_strings + + +def efficientnet(width_coefficient=None, + depth_coefficient=None, + dropout_rate=0.2, + drop_connect_rate=0.2, + image_size=None, + num_classes=1000): + """ Get block arguments according to parameter and coefficients. """ + blocks_args = [ + 'r1_k3_s11_e1_i32_o16_se0.25', + 'r2_k3_s22_e6_i16_o24_se0.25', + 'r2_k5_s22_e6_i24_o40_se0.25', + 'r3_k3_s22_e6_i40_o80_se0.25', + 'r3_k5_s11_e6_i80_o112_se0.25', + 'r4_k5_s22_e6_i112_o192_se0.25', + 'r1_k3_s11_e6_i192_o320_se0.25', + ] + blocks_args = BlockDecoder.decode(blocks_args) + + global_params = GlobalParams( + batch_norm_momentum=0.99, + batch_norm_epsilon=1e-3, + dropout_rate=dropout_rate, + drop_connect_rate=drop_connect_rate, + num_classes=num_classes, + width_coefficient=width_coefficient, + depth_coefficient=depth_coefficient, + depth_divisor=8, + min_depth=None, + image_size=image_size, ) + + return blocks_args, global_params + + +def get_model_params(model_name, override_params): + """ Get the block args and global params for a given model """ + if model_name.startswith('efficientnet'): + w, d, s, p = efficientnet_params(model_name) + blocks_args, global_params = efficientnet( + width_coefficient=w, + depth_coefficient=d, + dropout_rate=p, + image_size=s) + else: + raise NotImplementedError('model name is not pre-defined: %s' % + model_name) + if override_params: + global_params = global_params._replace(**override_params) + return blocks_args, global_params + + +url_map = { + 'efficientnet-b0': + '/home/aistudio/data/weights/efficientnet-b0-355c32eb.pdparams', + 'efficientnet-b1': + '/home/aistudio/data/weights/efficientnet-b1-f1951068.pdparams', + 'efficientnet-b2': + '/home/aistudio/data/weights/efficientnet-b2-8bb594d6.pdparams', + 'efficientnet-b3': + '/home/aistudio/data/weights/efficientnet-b3-5fb5a3c3.pdparams', + 'efficientnet-b4': + '/home/aistudio/data/weights/efficientnet-b4-6ed6700e.pdparams', + 'efficientnet-b5': + '/home/aistudio/data/weights/efficientnet-b5-b6417697.pdparams', + 'efficientnet-b6': + '/home/aistudio/data/weights/efficientnet-b6-c76e70fd.pdparams', + 'efficientnet-b7': + '/home/aistudio/data/weights/efficientnet-b7-dcc49843.pdparams', +} + +url_map_advprop = { + 'efficientnet-b0': + '/home/aistudio/data/weights/adv-efficientnet-b0-b64d5a18.pdparams', + 'efficientnet-b1': + '/home/aistudio/data/weights/adv-efficientnet-b1-0f3ce85a.pdparams', + 'efficientnet-b2': + '/home/aistudio/data/weights/adv-efficientnet-b2-6e9d97e5.pdparams', + 'efficientnet-b3': + '/home/aistudio/data/weights/adv-efficientnet-b3-cdd7c0f4.pdparams', + 'efficientnet-b4': + '/home/aistudio/data/weights/adv-efficientnet-b4-44fb3a87.pdparams', + 'efficientnet-b5': + '/home/aistudio/data/weights/adv-efficientnet-b5-86493f6b.pdparams', + 'efficientnet-b6': + '/home/aistudio/data/weights/adv-efficientnet-b6-ac80338e.pdparams', + 'efficientnet-b7': + '/home/aistudio/data/weights/adv-efficientnet-b7-4652b6dd.pdparams', + 'efficientnet-b8': + '/home/aistudio/data/weights/adv-efficientnet-b8-22a8fe65.pdparams', +} + + +def load_pretrained_weights(model, + model_name, + weights_path=None, + load_fc=True, + advprop=False): + """Loads pretrained weights from weights path or download using url. + Args: + model (Module): The whole model of efficientnet. + model_name (str): Model name of efficientnet. + weights_path (None or str): + str: path to pretrained weights file on the local disk. + None: use pretrained weights downloaded from the Internet. + load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. + advprop (bool): Whether to load pretrained weights + trained with advprop (valid when weights_path is None). + """ + + # AutoAugment or Advprop (different preprocessing) + url_map_ = url_map_advprop if advprop else url_map + state_dict = paddle.load(url_map_[model_name]) + + if load_fc: + model.set_state_dict(state_dict) + else: + state_dict.pop('_fc.weight') + state_dict.pop('_fc.bias') + model.set_state_dict(state_dict) + + print('Loaded pretrained weights for {}'.format(model_name)) diff --git a/paddle/scripts/infrt_build.sh b/paddle/scripts/infrt_build.sh index 6634f5396ac..2756e3b3211 100755 --- a/paddle/scripts/infrt_build.sh +++ b/paddle/scripts/infrt_build.sh @@ -44,11 +44,6 @@ function update_pd_ops() { cd ${PADDLE_ROOT}/tools/infrt/ python3 generate_pd_op_dialect_from_paddle_op_maker.py python3 generate_phi_kernel_dialect.py - # generate test model - cd ${PADDLE_ROOT} - mkdir -p ${PADDLE_ROOT}/build/models - python3 paddle/infrt/tests/models/abs_model.py ${PADDLE_ROOT}/build/paddle/infrt/tests/abs - python3 paddle/infrt/tests/models/resnet50_model.py ${PADDLE_ROOT}/build/models/resnet50/model } function init() { @@ -114,6 +109,14 @@ function create_fake_models() { # create multi_fc model, this will generate "multi_fc_model" python3 -m pip uninstall -y paddlepaddle python3 -m pip install *whl + + # generate test model + cd ${PADDLE_ROOT} + mkdir -p ${PADDLE_ROOT}/build/models + python3 paddle/infrt/tests/models/abs_model.py ${PADDLE_ROOT}/build/paddle/infrt/tests/abs + python3 paddle/infrt/tests/models/resnet50_model.py ${PADDLE_ROOT}/build/models/resnet50/model + python3 paddle/infrt/tests/models/efficientnet-b4/model.py ${PADDLE_ROOT}/build/models/efficientnet-b4/model + cd ${PADDLE_ROOT}/build python3 ${PADDLE_ROOT}/tools/infrt/fake_models/multi_fc.py python3 ${PADDLE_ROOT}/paddle/infrt/tests/models/linear.py -- GitLab