未验证 提交 f3531c7b 编写于 作者: H huzhiqiang 提交者: GitHub

[infrt] add efficientnet model (#41507)

上级 037c8099
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# url: https://aistudio.baidu.com/aistudio/projectdetail/3756986?forkThirdPart=1
from net import EfficientNet
from paddle.jit import to_static
from paddle.static import InputSpec
import paddle
import sys
model = EfficientNet.from_name('efficientnet-b4')
net = to_static(
model, input_spec=[InputSpec(
shape=[None, 3, 256, 256], name='x')])
paddle.jit.save(net, sys.argv[1])
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .efficientnet import EfficientNet
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .utils import (round_filters, round_repeats, drop_connect,
get_same_padding_conv2d, get_model_params,
efficientnet_params, load_pretrained_weights)
class MBConvBlock(nn.Layer):
"""
Mobile Inverted Residual Bottleneck Block
Args:
block_args (namedtuple): BlockArgs, see above
global_params (namedtuple): GlobalParam, see above
Attributes:
has_se (bool): Whether the block contains a Squeeze and Excitation layer.
"""
def __init__(self, block_args, global_params):
super().__init__()
self._block_args = block_args
self._bn_mom = global_params.batch_norm_momentum
self._bn_eps = global_params.batch_norm_epsilon
self.has_se = (self._block_args.se_ratio is not None) and (
0 < self._block_args.se_ratio <= 1)
self.id_skip = block_args.id_skip # skip connection and drop connect
# Get static or dynamic convolution depending on image size
Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
# Expansion phase
inp = self._block_args.input_filters # number of input channels
oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels
if self._block_args.expand_ratio != 1:
self._expand_conv = Conv2d(
in_channels=inp,
out_channels=oup,
kernel_size=1,
bias_attr=False)
self._bn0 = nn.BatchNorm2D(
num_features=oup, momentum=self._bn_mom, epsilon=self._bn_eps)
# Depthwise convolution phase
k = self._block_args.kernel_size
s = self._block_args.stride
self._depthwise_conv = Conv2d(
in_channels=oup,
out_channels=oup,
groups=oup, # groups makes it depthwise
kernel_size=k,
stride=s,
bias_attr=False)
self._bn1 = nn.BatchNorm2D(
num_features=oup, momentum=self._bn_mom, epsilon=self._bn_eps)
# Squeeze and Excitation layer, if desired
if self.has_se:
num_squeezed_channels = max(1,
int(self._block_args.input_filters *
self._block_args.se_ratio))
self._se_reduce = Conv2d(
in_channels=oup,
out_channels=num_squeezed_channels,
kernel_size=1)
self._se_expand = Conv2d(
in_channels=num_squeezed_channels,
out_channels=oup,
kernel_size=1)
# Output phase
final_oup = self._block_args.output_filters
self._project_conv = Conv2d(
in_channels=oup,
out_channels=final_oup,
kernel_size=1,
bias_attr=False)
self._bn2 = nn.BatchNorm2D(
num_features=final_oup, momentum=self._bn_mom, epsilon=self._bn_eps)
self._swish = nn.Hardswish()
def forward(self, inputs, drop_connect_rate=None):
"""
:param inputs: input tensor
:param drop_connect_rate: drop connect rate (float, between 0 and 1)
:return: output of block
"""
# Expansion and Depthwise Convolution
x = inputs
if self._block_args.expand_ratio != 1:
x = self._swish(self._bn0(self._expand_conv(inputs)))
x = self._swish(self._bn1(self._depthwise_conv(x)))
# Squeeze and Excitation
if self.has_se:
x_squeezed = F.adaptive_avg_pool2d(x, 1)
x_squeezed = self._se_expand(
self._swish(self._se_reduce(x_squeezed)))
x = F.sigmoid(x_squeezed) * x
x = self._bn2(self._project_conv(x))
# Skip connection and drop connect
input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
if drop_connect_rate:
x = drop_connect(
x, prob=drop_connect_rate, training=self.training)
x = x + inputs # skip connection
return x
def set_swish(self, memory_efficient=True):
"""Sets swish function as memory efficient (for training) or standard (for export)"""
self._swish = nn.Hardswish() if memory_efficient else nn.Swish()
class EfficientNet(nn.Layer):
"""
An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods
Args:
blocks_args (list): A list of BlockArgs to construct blocks
global_params (namedtuple): A set of GlobalParams shared between blocks
Example:
model = EfficientNet.from_pretrained('efficientnet-b0')
"""
def __init__(self, blocks_args=None, global_params=None):
super().__init__()
assert isinstance(blocks_args, list), 'blocks_args should be a list'
assert len(blocks_args) > 0, 'block args must be greater than 0'
self._global_params = global_params
self._blocks_args = blocks_args
# Get static or dynamic convolution depending on image size
Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)
# Batch norm parameters
bn_mom = self._global_params.batch_norm_momentum
bn_eps = self._global_params.batch_norm_epsilon
# Stem
in_channels = 3 # rgb
out_channels = round_filters(
32, self._global_params) # number of output channels
self._conv_stem = Conv2d(
in_channels, out_channels, kernel_size=3, stride=2, bias_attr=False)
self._bn0 = nn.BatchNorm2D(
num_features=out_channels, momentum=bn_mom, epsilon=bn_eps)
# Build blocks
self._blocks = nn.LayerList([])
for block_args in self._blocks_args:
# Update block input and output filters based on depth multiplier.
block_args = block_args._replace(
input_filters=round_filters(block_args.input_filters,
self._global_params),
output_filters=round_filters(block_args.output_filters,
self._global_params),
num_repeat=round_repeats(block_args.num_repeat,
self._global_params))
# The first block needs to take care of stride and filter size increase.
self._blocks.append(MBConvBlock(block_args, self._global_params))
if block_args.num_repeat > 1:
block_args = block_args._replace(
input_filters=block_args.output_filters, stride=1)
for _ in range(block_args.num_repeat - 1):
self._blocks.append(
MBConvBlock(block_args, self._global_params))
# Head
in_channels = block_args.output_filters # output of final block
out_channels = round_filters(1280, self._global_params)
self._conv_head = Conv2d(
in_channels, out_channels, kernel_size=1, bias_attr=False)
self._bn1 = nn.BatchNorm2D(
num_features=out_channels, momentum=bn_mom, epsilon=bn_eps)
# Final linear layer
self._avg_pooling = nn.AdaptiveAvgPool2D(1)
self._dropout = nn.Dropout(self._global_params.dropout_rate)
self._fc = nn.Linear(out_channels, self._global_params.num_classes)
self._swish = nn.Hardswish()
def set_swish(self, memory_efficient=True):
"""Sets swish function as memory efficient (for training) or standard (for export)"""
self._swish = nn.Hardswish() if memory_efficient else nn.Swish()
for block in self._blocks:
block.set_swish(memory_efficient)
def extract_features(self, inputs):
""" Returns output of the final convolution layer """
# Stem
x = self._swish(self._bn0(self._conv_stem(inputs)))
# Blocks
for idx, block in enumerate(self._blocks):
drop_connect_rate = self._global_params.drop_connect_rate
if drop_connect_rate:
drop_connect_rate *= float(idx) / len(self._blocks)
x = block(x, drop_connect_rate=drop_connect_rate)
# Head
x = self._swish(self._bn1(self._conv_head(x)))
return x
def forward(self, inputs):
""" Calls extract_features to extract features, applies final linear layer, and returns logits. """
bs = inputs.shape[0]
# Convolution layers
x = self.extract_features(inputs)
# Pooling and final linear layer
x = self._avg_pooling(x)
x = paddle.reshape(x, (bs, -1))
x = self._dropout(x)
x = self._fc(x)
return x
@classmethod
def from_name(cls, model_name, override_params=None):
cls._check_model_name_is_valid(model_name)
blocks_args, global_params = get_model_params(model_name,
override_params)
return cls(blocks_args, global_params)
@classmethod
def from_pretrained(cls,
model_name,
advprop=False,
num_classes=1000,
in_channels=3):
model = cls.from_name(
model_name, override_params={'num_classes': num_classes})
load_pretrained_weights(
model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
if in_channels != 3:
Conv2d = get_same_padding_conv2d(
image_size=model._global_params.image_size)
out_channels = round_filters(32, model._global_params)
model._conv_stem = Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=2,
bias_attr=False)
return model
@classmethod
def get_image_size(cls, model_name):
cls._check_model_name_is_valid(model_name)
_, _, res, _ = efficientnet_params(model_name)
return res
@classmethod
def _check_model_name_is_valid(cls, model_name):
""" Validates model name. """
valid_models = ['efficientnet-b' + str(i) for i in range(9)]
if model_name not in valid_models:
raise ValueError('model_name should be one of: ' + ', '.join(
valid_models))
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import math
from functools import partial
import collections
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
# Parameters for the entire model (stem, all blocks, and head)
GlobalParams = collections.namedtuple('GlobalParams', [
'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'num_classes',
'width_coefficient', 'depth_coefficient', 'depth_divisor', 'min_depth',
'drop_connect_rate', 'image_size'
])
# Parameters for an individual model block
BlockArgs = collections.namedtuple('BlockArgs', [
'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
'expand_ratio', 'id_skip', 'stride', 'se_ratio'
])
# Change namedtuple defaults
GlobalParams.__new__.__defaults__ = (None, ) * len(GlobalParams._fields)
BlockArgs.__new__.__defaults__ = (None, ) * len(BlockArgs._fields)
def round_filters(filters, global_params):
""" Calculate and round number of filters based on depth multiplier. """
multiplier = global_params.width_coefficient
if not multiplier:
return filters
divisor = global_params.depth_divisor
min_depth = global_params.min_depth
filters *= multiplier
min_depth = min_depth or divisor
new_filters = max(min_depth,
int(filters + divisor / 2) // divisor * divisor)
if new_filters < 0.9 * filters: # prevent rounding by more than 10%
new_filters += divisor
return int(new_filters)
def round_repeats(repeats, global_params):
""" Round number of filters based on depth multiplier. """
multiplier = global_params.depth_coefficient
if not multiplier:
return repeats
return int(math.ceil(multiplier * repeats))
def drop_connect(inputs, prob, training):
"""Drop input connection"""
if not training:
return inputs
keep_prob = 1.0 - prob
inputs_shape = paddle.shape(inputs)
random_tensor = keep_prob + paddle.rand(shape=[inputs_shape[0], 1, 1, 1])
binary_tensor = paddle.floor(random_tensor)
output = inputs / keep_prob * binary_tensor
return output
def get_same_padding_conv2d(image_size=None):
""" Chooses static padding if you have specified an image size, and dynamic padding otherwise.
Static padding is necessary for ONNX exporting of models. """
if image_size is None:
return Conv2dDynamicSamePadding
else:
return partial(Conv2dStaticSamePadding, image_size=image_size)
class Conv2dDynamicSamePadding(nn.Conv2D):
""" 2D Convolutions like TensorFlow, for a dynamic image size """
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
dilation=1,
groups=1,
bias_attr=None):
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
0,
dilation,
groups,
bias_attr=bias_attr)
self.stride = self._stride if len(
self._stride) == 2 else [self._stride[0]] * 2
def forward(self, x):
ih, iw = x.shape[-2:]
kh, kw = self.weight.shape[-2:]
sh, sw = self.stride
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
pad_h = max((oh - 1) * self.stride[0] +
(kh - 1) * self._dilation[0] + 1 - ih, 0)
pad_w = max((ow - 1) * self.stride[1] +
(kw - 1) * self._dilation[1] + 1 - iw, 0)
if pad_h > 0 or pad_w > 0:
x = F.pad(x, [
pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
])
return F.conv2d(x, self.weight, self.bias, self.stride, self._padding,
self._dilation, self._groups)
class Conv2dStaticSamePadding(nn.Conv2D):
""" 2D Convolutions like TensorFlow, for a fixed image size"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
image_size=None,
**kwargs):
if 'stride' in kwargs and isinstance(kwargs['stride'], list):
kwargs['stride'] = kwargs['stride'][0]
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
self.stride = self._stride if len(
self._stride) == 2 else [self._stride[0]] * 2
# Calculate padding based on image size and save it
assert image_size is not None
ih, iw = image_size if type(
image_size) == list else [image_size, image_size]
kh, kw = self.weight.shape[-2:]
sh, sw = self.stride
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
pad_h = max((oh - 1) * self.stride[0] +
(kh - 1) * self._dilation[0] + 1 - ih, 0)
pad_w = max((ow - 1) * self.stride[1] +
(kw - 1) * self._dilation[1] + 1 - iw, 0)
if pad_h > 0 or pad_w > 0:
self.static_padding = nn.Pad2D([
pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
])
else:
self.static_padding = Identity()
def forward(self, x):
x = self.static_padding(x)
x = F.conv2d(x, self.weight, self.bias, self.stride, self._padding,
self._dilation, self._groups)
return x
class Identity(nn.Layer):
def __init__(self, ):
super().__init__()
def forward(self, x):
return x
def efficientnet_params(model_name):
""" Map EfficientNet model name to parameter coefficients. """
params_dict = {
# Coefficients: width,depth,resolution,dropout
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
'efficientnet-b8': (2.2, 3.6, 672, 0.5),
'efficientnet-l2': (4.3, 5.3, 800, 0.5),
}
return params_dict[model_name]
class BlockDecoder(object):
""" Block Decoder for readability, straight from the official TensorFlow repository """
@staticmethod
def _decode_block_string(block_string):
""" Gets a block through a string notation of arguments. """
assert isinstance(block_string, str)
ops = block_string.split('_')
options = {}
for op in ops:
splits = re.split(r'(\d.*)', op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
# Check stride
assert (('s' in options and len(options['s']) == 1) or
(len(options['s']) == 2 and options['s'][0] == options['s'][1]))
return BlockArgs(
kernel_size=int(options['k']),
num_repeat=int(options['r']),
input_filters=int(options['i']),
output_filters=int(options['o']),
expand_ratio=int(options['e']),
id_skip=('noskip' not in block_string),
se_ratio=float(options['se']) if 'se' in options else None,
stride=[int(options['s'][0])])
@staticmethod
def _encode_block_string(block):
"""Encodes a block to a string."""
args = [
'r%d' % block.num_repeat, 'k%d' % block.kernel_size, 's%d%d' %
(block.strides[0], block.strides[1]), 'e%s' % block.expand_ratio,
'i%d' % block.input_filters, 'o%d' % block.output_filters
]
if 0 < block.se_ratio <= 1:
args.append('se%s' % block.se_ratio)
if block.id_skip is False:
args.append('noskip')
return '_'.join(args)
@staticmethod
def decode(string_list):
"""
Decodes a list of string notations to specify blocks inside the network.
:param string_list: a list of strings, each string is a notation of block
:return: a list of BlockArgs namedtuples of block args
"""
assert isinstance(string_list, list)
blocks_args = []
for block_string in string_list:
blocks_args.append(BlockDecoder._decode_block_string(block_string))
return blocks_args
@staticmethod
def encode(blocks_args):
"""
Encodes a list of BlockArgs to a list of strings.
:param blocks_args: a list of BlockArgs namedtuples of block args
:return: a list of strings, each string is a notation of block
"""
block_strings = []
for block in blocks_args:
block_strings.append(BlockDecoder._encode_block_string(block))
return block_strings
def efficientnet(width_coefficient=None,
depth_coefficient=None,
dropout_rate=0.2,
drop_connect_rate=0.2,
image_size=None,
num_classes=1000):
""" Get block arguments according to parameter and coefficients. """
blocks_args = [
'r1_k3_s11_e1_i32_o16_se0.25',
'r2_k3_s22_e6_i16_o24_se0.25',
'r2_k5_s22_e6_i24_o40_se0.25',
'r3_k3_s22_e6_i40_o80_se0.25',
'r3_k5_s11_e6_i80_o112_se0.25',
'r4_k5_s22_e6_i112_o192_se0.25',
'r1_k3_s11_e6_i192_o320_se0.25',
]
blocks_args = BlockDecoder.decode(blocks_args)
global_params = GlobalParams(
batch_norm_momentum=0.99,
batch_norm_epsilon=1e-3,
dropout_rate=dropout_rate,
drop_connect_rate=drop_connect_rate,
num_classes=num_classes,
width_coefficient=width_coefficient,
depth_coefficient=depth_coefficient,
depth_divisor=8,
min_depth=None,
image_size=image_size, )
return blocks_args, global_params
def get_model_params(model_name, override_params):
""" Get the block args and global params for a given model """
if model_name.startswith('efficientnet'):
w, d, s, p = efficientnet_params(model_name)
blocks_args, global_params = efficientnet(
width_coefficient=w,
depth_coefficient=d,
dropout_rate=p,
image_size=s)
else:
raise NotImplementedError('model name is not pre-defined: %s' %
model_name)
if override_params:
global_params = global_params._replace(**override_params)
return blocks_args, global_params
url_map = {
'efficientnet-b0':
'/home/aistudio/data/weights/efficientnet-b0-355c32eb.pdparams',
'efficientnet-b1':
'/home/aistudio/data/weights/efficientnet-b1-f1951068.pdparams',
'efficientnet-b2':
'/home/aistudio/data/weights/efficientnet-b2-8bb594d6.pdparams',
'efficientnet-b3':
'/home/aistudio/data/weights/efficientnet-b3-5fb5a3c3.pdparams',
'efficientnet-b4':
'/home/aistudio/data/weights/efficientnet-b4-6ed6700e.pdparams',
'efficientnet-b5':
'/home/aistudio/data/weights/efficientnet-b5-b6417697.pdparams',
'efficientnet-b6':
'/home/aistudio/data/weights/efficientnet-b6-c76e70fd.pdparams',
'efficientnet-b7':
'/home/aistudio/data/weights/efficientnet-b7-dcc49843.pdparams',
}
url_map_advprop = {
'efficientnet-b0':
'/home/aistudio/data/weights/adv-efficientnet-b0-b64d5a18.pdparams',
'efficientnet-b1':
'/home/aistudio/data/weights/adv-efficientnet-b1-0f3ce85a.pdparams',
'efficientnet-b2':
'/home/aistudio/data/weights/adv-efficientnet-b2-6e9d97e5.pdparams',
'efficientnet-b3':
'/home/aistudio/data/weights/adv-efficientnet-b3-cdd7c0f4.pdparams',
'efficientnet-b4':
'/home/aistudio/data/weights/adv-efficientnet-b4-44fb3a87.pdparams',
'efficientnet-b5':
'/home/aistudio/data/weights/adv-efficientnet-b5-86493f6b.pdparams',
'efficientnet-b6':
'/home/aistudio/data/weights/adv-efficientnet-b6-ac80338e.pdparams',
'efficientnet-b7':
'/home/aistudio/data/weights/adv-efficientnet-b7-4652b6dd.pdparams',
'efficientnet-b8':
'/home/aistudio/data/weights/adv-efficientnet-b8-22a8fe65.pdparams',
}
def load_pretrained_weights(model,
model_name,
weights_path=None,
load_fc=True,
advprop=False):
"""Loads pretrained weights from weights path or download using url.
Args:
model (Module): The whole model of efficientnet.
model_name (str): Model name of efficientnet.
weights_path (None or str):
str: path to pretrained weights file on the local disk.
None: use pretrained weights downloaded from the Internet.
load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model.
advprop (bool): Whether to load pretrained weights
trained with advprop (valid when weights_path is None).
"""
# AutoAugment or Advprop (different preprocessing)
url_map_ = url_map_advprop if advprop else url_map
state_dict = paddle.load(url_map_[model_name])
if load_fc:
model.set_state_dict(state_dict)
else:
state_dict.pop('_fc.weight')
state_dict.pop('_fc.bias')
model.set_state_dict(state_dict)
print('Loaded pretrained weights for {}'.format(model_name))
......@@ -44,11 +44,6 @@ function update_pd_ops() {
cd ${PADDLE_ROOT}/tools/infrt/
python3 generate_pd_op_dialect_from_paddle_op_maker.py
python3 generate_phi_kernel_dialect.py
# generate test model
cd ${PADDLE_ROOT}
mkdir -p ${PADDLE_ROOT}/build/models
python3 paddle/infrt/tests/models/abs_model.py ${PADDLE_ROOT}/build/paddle/infrt/tests/abs
python3 paddle/infrt/tests/models/resnet50_model.py ${PADDLE_ROOT}/build/models/resnet50/model
}
function init() {
......@@ -114,6 +109,14 @@ function create_fake_models() {
# create multi_fc model, this will generate "multi_fc_model"
python3 -m pip uninstall -y paddlepaddle
python3 -m pip install *whl
# generate test model
cd ${PADDLE_ROOT}
mkdir -p ${PADDLE_ROOT}/build/models
python3 paddle/infrt/tests/models/abs_model.py ${PADDLE_ROOT}/build/paddle/infrt/tests/abs
python3 paddle/infrt/tests/models/resnet50_model.py ${PADDLE_ROOT}/build/models/resnet50/model
python3 paddle/infrt/tests/models/efficientnet-b4/model.py ${PADDLE_ROOT}/build/models/efficientnet-b4/model
cd ${PADDLE_ROOT}/build
python3 ${PADDLE_ROOT}/tools/infrt/fake_models/multi_fc.py
python3 ${PADDLE_ROOT}/paddle/infrt/tests/models/linear.py
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册