提交 c351dac6 编写于 作者: Y Yang Nie 提交者: Tingquan Gao

add tinynet

上级 f8fdc5fd
......@@ -80,6 +80,7 @@ from .model_zoo.micronet import MicroNet_M0, MicroNet_M1, MicroNet_M2, MicroNet_
from .model_zoo.mobilenext import MobileNeXt_x0_35, MobileNeXt_x0_5, MobileNeXt_x0_75, MobileNeXt_x1_0, MobileNeXt_x1_4
from .model_zoo.mobilevit_v2 import MobileViTV2_x0_5, MobileViTV2_x0_75, MobileViTV2_x1_0, MobileViTV2_x1_25, MobileViTV2_x1_5, MobileViTV2_x1_75, MobileViTV2_x2_0
from .model_zoo.mobilevit_v3 import MobileViTv3_XXS, MobileViTv3_XS, MobileViTv3_S, MobileViTv3_XXS_L2, MobileViTv3_XS_L2, MobileViTv3_S_L2, MobileViTv3_x0_5, MobileViTv3_x0_75, MobileViTv3_x1_0
from .model_zoo.tinynet import TinyNet_A, TinyNet_B, TinyNet_C, TinyNet_D, TinyNet_E
from .variant_models.resnet_variant import ResNet50_last_stage_stride1
from .variant_models.resnet_variant import ResNet50_adaptive_max_pool2d
......
......@@ -60,6 +60,7 @@ GlobalParams = collections.namedtuple('GlobalParams', [
'width_coefficient',
'depth_coefficient',
'depth_divisor',
'depth_trunc',
'min_depth',
'drop_connect_rate',
])
......@@ -77,6 +78,7 @@ def efficientnet_params(model_name):
""" Map EfficientNet model name to parameter coefficients. """
params_dict = {
# Coefficients: width,depth,resolution,dropout
'efficientnet-b0-small': (1.0, 1.0, 224, 0.2),
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
......@@ -114,6 +116,7 @@ def efficientnet(width_coefficient=None,
width_coefficient=width_coefficient,
depth_coefficient=depth_coefficient,
depth_divisor=8,
depth_trunc='ceil',
min_depth=None)
return blocks_args, global_params
......@@ -154,7 +157,10 @@ def round_repeats(repeats, global_params):
multiplier = global_params.depth_coefficient
if not multiplier:
return repeats
return int(math.ceil(multiplier * repeats))
if global_params.depth_trunc == 'round':
return max(1, round(multiplier * repeats))
else:
return int(math.ceil(multiplier * repeats))
class BlockDecoder(object):
......@@ -314,10 +320,10 @@ class Conv2ds(TheseusLayer):
padding = ((stride - 1) + dilation * (filter_size - 1)) // 2
return padding
inps = 1 if model_name == None and cur_stage == None else inp_shape[
model_name][cur_stage]
self.need_crop = False
if padding_type == "SAME":
inps = 1 if model_name == None and cur_stage == None else inp_shape[
model_name][cur_stage]
top_padding, bottom_padding = cal_padding(inps, stride,
filter_size)
left_padding, right_padding = cal_padding(inps, stride,
......@@ -398,12 +404,13 @@ class ConvBNLayer(TheseusLayer):
if use_bn is True:
bn_name = name + bn_name
param_attr, bias_attr = init_batch_norm_layer(bn_name)
momentum = global_params.batch_norm_momentum
epsilon = global_params.batch_norm_epsilon
self._bn = BatchNorm(
num_channels=output_channels,
act=bn_act,
momentum=0.99,
momentum=momentum,
epsilon=epsilon,
moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + "_variance",
......@@ -501,12 +508,12 @@ class ProjectConvNorm(TheseusLayer):
cur_stage=None):
super(ProjectConvNorm, self).__init__()
final_oup = block_args.output_filters
self.final_oup = block_args.output_filters
self._conv = ConvBNLayer(
input_channels,
1,
final_oup,
self.final_oup,
global_params=global_params,
bn_act=None,
padding_type=padding_type,
......@@ -619,6 +626,8 @@ class MbConvBlock(TheseusLayer):
model_name=model_name,
cur_stage=cur_stage)
self.final_oup = self._pcn.final_oup
def forward(self, inputs):
x = inputs
if self.expand_ratio != 1:
......@@ -647,10 +656,11 @@ class ConvStemNorm(TheseusLayer):
_global_params,
name=None,
model_name=None,
fix_stem=False,
cur_stage=None):
super(ConvStemNorm, self).__init__()
output_channels = round_filters(32, _global_params)
output_channels = 32 if fix_stem else round_filters(32, _global_params)
self._conv = ConvBNLayer(
input_channels,
filter_size=3,
......@@ -676,7 +686,8 @@ class ExtractFeatures(TheseusLayer):
_global_params,
padding_type,
use_se,
model_name=None):
model_name=None,
fix_stem=False):
super(ExtractFeatures, self).__init__()
self._global_params = _global_params
......@@ -686,6 +697,7 @@ class ExtractFeatures(TheseusLayer):
padding_type=padding_type,
_global_params=_global_params,
model_name=model_name,
fix_stem=fix_stem,
cur_stage=0)
self.block_args_copy = copy.deepcopy(_block_args)
......@@ -702,12 +714,14 @@ class ExtractFeatures(TheseusLayer):
for _ in range(block_arg.num_repeat - 1):
block_size += 1
self.final_oup = None
self.conv_seq = []
cur_stage = 1
for block_args in _block_args:
for block_idx, block_args in enumerate(_block_args):
if not (fix_stem and block_idx == 0):
block_args = block_args._replace(input_filters=round_filters(
block_args.input_filters, _global_params))
block_args = block_args._replace(
input_filters=round_filters(block_args.input_filters,
_global_params),
output_filters=round_filters(block_args.output_filters,
_global_params),
num_repeat=round_repeats(block_args.num_repeat,
......@@ -730,6 +744,7 @@ class ExtractFeatures(TheseusLayer):
model_name=model_name,
cur_stage=cur_stage))
self.conv_seq.append(_mc_block)
self.final_oup = _mc_block.final_oup
idx += 1
if block_args.num_repeat > 1:
block_args = block_args._replace(
......@@ -751,6 +766,7 @@ class ExtractFeatures(TheseusLayer):
model_name=model_name,
cur_stage=cur_stage))
self.conv_seq.append(_mc_block)
self.final_oup = _mc_block.final_oup
idx += 1
cur_stage += 1
......@@ -764,17 +780,20 @@ class ExtractFeatures(TheseusLayer):
class EfficientNet(TheseusLayer):
def __init__(self,
block_args,
global_params,
name="b0",
padding_type="SAME",
override_params=None,
use_se=True,
fix_stem=False,
num_features=None,
class_num=1000):
super(EfficientNet, self).__init__()
model_name = 'efficientnet-' + name
self.name = name
self._block_args, self._global_params = get_model_params(
model_name, override_params)
self.fix_stem = fix_stem
self._block_args = block_args
self._global_params = global_params
self.padding_type = padding_type
self.use_se = use_se
......@@ -784,25 +803,13 @@ class EfficientNet(TheseusLayer):
self._global_params,
self.padding_type,
self.use_se,
model_name=self.name)
output_channels = round_filters(1280, self._global_params)
if name == "b0_small" or name == "b0" or name == "b1":
oup = 320
elif name == "b2":
oup = 352
elif name == "b3":
oup = 384
elif name == "b4":
oup = 448
elif name == "b5":
oup = 512
elif name == "b6":
oup = 576
elif name == "b7":
oup = 640
model_name=self.name,
fix_stem=self.fix_stem)
output_channels = num_features or round_filters(1280,
self._global_params)
self._conv = ConvBNLayer(
oup,
self._ef.final_oup,
1,
output_channels,
global_params=self._global_params,
......@@ -856,10 +863,13 @@ def EfficientNetB0_small(padding_type='DYNAMIC',
pretrained=False,
use_ssld=False,
**kwargs):
block_args, global_params = get_model_params("efficientnet-b0-small",
override_params)
model = EfficientNet(
block_args,
global_params,
name='b0',
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB0_small"])
......@@ -872,10 +882,13 @@ def EfficientNetB0(padding_type='SAME',
pretrained=False,
use_ssld=False,
**kwargs):
block_args, global_params = get_model_params("efficientnet-b0",
override_params)
model = EfficientNet(
block_args,
global_params,
name='b0',
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB0"])
......@@ -888,10 +901,13 @@ def EfficientNetB1(padding_type='SAME',
pretrained=False,
use_ssld=False,
**kwargs):
block_args, global_params = get_model_params("efficientnet-b1",
override_params)
model = EfficientNet(
block_args,
global_params,
name='b1',
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB1"])
......@@ -904,10 +920,13 @@ def EfficientNetB2(padding_type='SAME',
pretrained=False,
use_ssld=False,
**kwargs):
block_args, global_params = get_model_params("efficientnet-b2",
override_params)
model = EfficientNet(
block_args,
global_params,
name='b2',
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB2"])
......@@ -920,10 +939,13 @@ def EfficientNetB3(padding_type='SAME',
pretrained=False,
use_ssld=False,
**kwargs):
block_args, global_params = get_model_params("efficientnet-b3",
override_params)
model = EfficientNet(
block_args,
global_params,
name='b3',
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB3"])
......@@ -936,10 +958,13 @@ def EfficientNetB4(padding_type='SAME',
pretrained=False,
use_ssld=False,
**kwargs):
block_args, global_params = get_model_params("efficientnet-b4",
override_params)
model = EfficientNet(
block_args,
global_params,
name='b4',
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB4"])
......@@ -952,10 +977,13 @@ def EfficientNetB5(padding_type='SAME',
pretrained=False,
use_ssld=False,
**kwargs):
block_args, global_params = get_model_params("efficientnet-b5",
override_params)
model = EfficientNet(
block_args,
global_params,
name='b5',
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB5"])
......@@ -968,10 +996,13 @@ def EfficientNetB6(padding_type='SAME',
pretrained=False,
use_ssld=False,
**kwargs):
block_args, global_params = get_model_params("efficientnet-b6",
override_params)
model = EfficientNet(
block_args,
global_params,
name='b6',
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB6"])
......@@ -984,10 +1015,13 @@ def EfficientNetB7(padding_type='SAME',
pretrained=False,
use_ssld=False,
**kwargs):
block_args, global_params = get_model_params("efficientnet-b7",
override_params)
model = EfficientNet(
block_args,
global_params,
name='b7',
padding_type=padding_type,
override_params=override_params,
use_se=use_se,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["EfficientNetB7"])
......
# copyright (c) 2023 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Code was based on https://gitee.com/mindspore/models/tree/master/research/cv/tinynet
# reference: https://arxiv.org/abs/2010.14819
import paddle.nn as nn
from .efficientnet import EfficientNet, efficientnet
from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
MODEL_URLS = {
"TinyNet_A": "",
"TinyNet_B": "",
"TinyNet_C": "",
"TinyNet_D": "",
"TinyNet_E": "",
}
__all__ = list(MODEL_URLS.keys())
def tinynet_params(model_name):
""" Map TinyNet model name to parameter coefficients. """
params_dict = {
# Coefficients: width,depth,resolution,dropout
"tinynet-a": (1.00, 1.200, 192, 0.2),
"tinynet-b": (0.75, 1.100, 188, 0.2),
"tinynet-c": (0.54, 0.850, 184, 0.2),
"tinynet-d": (0.54, 0.695, 152, 0.2),
"tinynet-e": (0.51, 0.600, 106, 0.2),
}
return params_dict[model_name]
def get_model_params(model_name, override_params):
""" Get the block args and global params for a given model """
if model_name.startswith('tinynet'):
w, d, _, p = tinynet_params(model_name)
blocks_args, global_params = efficientnet(
width_coefficient=w, depth_coefficient=d, dropout_rate=p)
else:
raise NotImplementedError('model name is not pre-defined: %s' %
model_name)
if override_params:
global_params = global_params._replace(**override_params)
return blocks_args, global_params
class TinyNet(EfficientNet):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Conv2D):
fin_in = m.weight.shape[1] * m.weight.shape[2] * m.weight.shape[3]
std = (2 / fin_in)**0.5
nn.initializer.Normal(std=std)(m.weight)
if m.bias is not None:
nn.initializer.Constant(0)(m.bias)
elif isinstance(m, nn.Linear):
fin_in = m.weight.shape[0]
bound = 1 / fin_in**0.5
nn.initializer.Uniform(-bound, bound)(m.weight)
if m.bias is not None:
nn.initializer.Constant(0)(m.bias)
def _load_pretrained(pretrained, model, model_url, use_ssld=False):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else:
raise RuntimeError(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def TinyNet_A(padding_type='DYNAMIC',
override_params=None,
use_se=True,
pretrained=False,
use_ssld=False,
**kwargs):
block_args, global_params = get_model_params("tinynet-a", override_params)
model = TinyNet(
block_args,
global_params,
name='a',
padding_type=padding_type,
use_se=use_se,
fix_stem=True,
num_features=1280,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["TinyNet_A"], use_ssld)
return model
def TinyNet_B(padding_type='DYNAMIC',
override_params=None,
use_se=True,
pretrained=False,
use_ssld=False,
**kwargs):
block_args, global_params = get_model_params("tinynet-b", override_params)
model = TinyNet(
block_args,
global_params,
name='b',
padding_type=padding_type,
use_se=use_se,
fix_stem=True,
num_features=1280,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["TinyNet_B"], use_ssld)
return model
def TinyNet_C(padding_type='DYNAMIC',
override_params=None,
use_se=True,
pretrained=False,
use_ssld=False,
**kwargs):
block_args, global_params = get_model_params("tinynet-c", override_params)
model = TinyNet(
block_args,
global_params,
name='c',
padding_type=padding_type,
use_se=use_se,
fix_stem=True,
num_features=1280,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["TinyNet_C"], use_ssld)
return model
def TinyNet_D(padding_type='DYNAMIC',
override_params=None,
use_se=True,
pretrained=False,
use_ssld=False,
**kwargs):
block_args, global_params = get_model_params("tinynet-d", override_params)
model = TinyNet(
block_args,
global_params,
name='d',
padding_type=padding_type,
use_se=use_se,
fix_stem=True,
num_features=1280,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["TinyNet_D"], use_ssld)
return model
def TinyNet_E(padding_type='DYNAMIC',
override_params=None,
use_se=True,
pretrained=False,
use_ssld=False,
**kwargs):
block_args, global_params = get_model_params("tinynet-e", override_params)
model = TinyNet(
block_args,
global_params,
name='e',
padding_type=padding_type,
use_se=use_se,
fix_stem=True,
num_features=1280,
**kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["TinyNet_E"], use_ssld)
return model
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 450
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 192, 192]
save_inference_dir: ./inference
# model ema
EMA:
decay: 0.9999
# model architecture
Arch:
name: TinyNet_A
class_num: 1000
override_params:
batch_norm_momentum: 0.9
batch_norm_epsilon: 1e-5
depth_trunc: round
drop_connect_rate: 0.1
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: RMSProp
momentum: 0.9
rho: 0.9
epsilon: 0.001
one_dim_param_no_weight_decay: True
lr:
name: Step
learning_rate: 0.048
step_size: 2.4
gamma: 0.97
warmup_epoch: 3
warmup_start_lr: 1e-6
regularizer:
name: 'L2'
coeff: 1e-5
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
backend: pil
- RandCropImage:
size: 192
interpolation: bicubic
backend: pil
use_log_aspect: True
- RandFlipImage:
flip_code: 1
- ColorJitter:
brightness: 0.4
contrast: 0.4
saturation: 0.4
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: True
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_np: False
channel_first: False
backend: pil
- ResizeImage:
resize_short: 219
interpolation: bicubic
backend: pil
- CropImage:
size: 192
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_np: False
channel_first: False
- ResizeImage:
resize_short: 219
interpolation: bicubic
backend: pil
- CropImage:
size: 192
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 450
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 188, 188]
save_inference_dir: ./inference
# model ema
EMA:
decay: 0.9999
# model architecture
Arch:
name: TinyNet_B
class_num: 1000
override_params:
batch_norm_momentum: 0.9
batch_norm_epsilon: 1e-5
depth_trunc: round
drop_connect_rate: 0.1
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: RMSProp
momentum: 0.9
rho: 0.9
epsilon: 0.001
one_dim_param_no_weight_decay: True
lr:
name: Step
learning_rate: 0.048
step_size: 2.4
gamma: 0.97
warmup_epoch: 3
warmup_start_lr: 1e-6
regularizer:
name: 'L2'
coeff: 1e-5
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
backend: pil
- RandCropImage:
size: 188
interpolation: bicubic
backend: pil
use_log_aspect: True
- RandFlipImage:
flip_code: 1
- ColorJitter:
brightness: 0.4
contrast: 0.4
saturation: 0.4
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: True
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_np: False
channel_first: False
backend: pil
- ResizeImage:
resize_short: 214
interpolation: bicubic
backend: pil
- CropImage:
size: 188
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_np: False
channel_first: False
- ResizeImage:
resize_short: 214
interpolation: bicubic
backend: pil
- CropImage:
size: 188
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 450
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 184, 184]
save_inference_dir: ./inference
# model ema
EMA:
decay: 0.9999
# model architecture
Arch:
name: TinyNet_C
class_num: 1000
override_params:
batch_norm_momentum: 0.9
batch_norm_epsilon: 1e-5
depth_trunc: round
drop_connect_rate: 0.0
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: RMSProp
momentum: 0.9
rho: 0.9
epsilon: 0.001
one_dim_param_no_weight_decay: True
lr:
name: Step
learning_rate: 0.048
step_size: 2.4
gamma: 0.97
warmup_epoch: 3
warmup_start_lr: 1e-6
regularizer:
name: 'L2'
coeff: 1e-5
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
backend: pil
- RandCropImage:
size: 184
interpolation: bicubic
backend: pil
use_log_aspect: True
- RandFlipImage:
flip_code: 1
- ColorJitter:
brightness: 0.4
contrast: 0.4
saturation: 0.4
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: True
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_np: False
channel_first: False
backend: pil
- ResizeImage:
resize_short: 210
interpolation: bicubic
backend: pil
- CropImage:
size: 184
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_np: False
channel_first: False
- ResizeImage:
resize_short: 210
interpolation: bicubic
backend: pil
- CropImage:
size: 184
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 450
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 152, 152]
save_inference_dir: ./inference
# model ema
EMA:
decay: 0.9999
# model architecture
Arch:
name: TinyNet_D
class_num: 1000
override_params:
batch_norm_momentum: 0.9
batch_norm_epsilon: 1e-5
depth_trunc: round
drop_connect_rate: 0.0
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: RMSProp
momentum: 0.9
rho: 0.9
epsilon: 0.001
one_dim_param_no_weight_decay: True
lr:
name: Step
learning_rate: 0.048
step_size: 2.4
gamma: 0.97
warmup_epoch: 3
warmup_start_lr: 1e-6
regularizer:
name: 'L2'
coeff: 1e-5
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
backend: pil
- RandCropImage:
size: 152
interpolation: bicubic
backend: pil
use_log_aspect: True
- RandFlipImage:
flip_code: 1
- ColorJitter:
brightness: 0.4
contrast: 0.4
saturation: 0.4
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: True
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_np: False
channel_first: False
backend: pil
- ResizeImage:
resize_short: 173
interpolation: bicubic
backend: pil
- CropImage:
size: 152
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_np: False
channel_first: False
- ResizeImage:
resize_short: 173
interpolation: bicubic
backend: pil
- CropImage:
size: 152
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 450
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_shape: [3, 106, 106]
save_inference_dir: ./inference
# model ema
EMA:
decay: 0.9999
# model architecture
Arch:
name: TinyNet_E
class_num: 1000
override_params:
batch_norm_momentum: 0.9
batch_norm_epsilon: 1e-5
depth_trunc: round
drop_connect_rate: 0.0
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: RMSProp
momentum: 0.9
rho: 0.9
epsilon: 0.001
one_dim_param_no_weight_decay: True
lr:
name: Step
learning_rate: 0.048
step_size: 2.4
gamma: 0.97
warmup_epoch: 3
warmup_start_lr: 1e-6
regularizer:
name: 'L2'
coeff: 1e-5
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
backend: pil
- RandCropImage:
size: 106
interpolation: bicubic
backend: pil
use_log_aspect: True
- RandFlipImage:
flip_code: 1
- ColorJitter:
brightness: 0.4
contrast: 0.4
saturation: 0.4
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: True
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_np: False
channel_first: False
backend: pil
- ResizeImage:
resize_short: 121
interpolation: bicubic
backend: pil
- CropImage:
size: 106
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_np: False
channel_first: False
- ResizeImage:
resize_short: 121
interpolation: bicubic
backend: pil
- CropImage:
size: 106
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
......@@ -339,7 +339,7 @@ class Step(LRBase):
epochs (int): total epoch(s)
step_each_epoch (int): number of iterations within an epoch
learning_rate (float): learning rate
step_size (int): the interval to update.
step_size (int|float): the interval to update.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma``. It should be less than 1.0. Default: 0.1.
warmup_epoch (int, optional): The epoch numbers for LinearWarmup. Defaults to 0.
warmup_start_lr (float, optional): start learning rate within warmup. Defaults to 0.0.
......@@ -361,7 +361,7 @@ class Step(LRBase):
super(Step, self).__init__(epochs, step_each_epoch, learning_rate,
warmup_epoch, warmup_start_lr, last_epoch,
by_epoch)
self.step_size = step_size * step_each_epoch
self.step_size = int(step_size * step_each_epoch)
self.gamma = gamma
if self.by_epoch:
self.step_size = step_size
......
......@@ -215,7 +215,9 @@ class RMSProp(object):
epsilon=1e-6,
weight_decay=None,
grad_clip=None,
multi_precision=False):
multi_precision=False,
no_weight_decay_name=None,
one_dim_param_no_weight_decay=False):
super().__init__()
self.learning_rate = learning_rate
self.momentum = momentum
......@@ -223,11 +225,33 @@ class RMSProp(object):
self.epsilon = epsilon
self.weight_decay = weight_decay
self.grad_clip = grad_clip
self.no_weight_decay_name_list = no_weight_decay_name.split(
) if no_weight_decay_name else []
self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay
def __call__(self, model_list):
# model_list is None in static graph
parameters = sum([m.parameters() for m in model_list],
[]) if model_list else None
parameters = None
if len(self.no_weight_decay_name_list) > 0:
params_with_decay = []
params_without_decay = []
for m in model_list:
params = [p for n, p in m.named_parameters() \
if not any(nd in n for nd in self.no_weight_decay_name_list)]
params_with_decay.extend(params)
params = [p for n, p in m.named_parameters() \
if any(nd in n for nd in self.no_weight_decay_name_list) or (self.one_dim_param_no_weight_decay and len(p.shape) == 1)]
params_without_decay.extend(params)
parameters = [{
"params": params_with_decay,
"weight_decay": self.weight_decay
}, {
"params": params_without_decay,
"weight_decay": 0.0
}]
else:
parameters = sum([m.parameters() for m in model_list],
[]) if model_list else None
opt = optim.RMSProp(
learning_rate=self.learning_rate,
momentum=self.momentum,
......
===========================train_params===========================
model_name:TinyNet_A
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/TinyNet/TinyNet_A.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False -o Global.eval_during_train=False -o Global.save_interval=2
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/TinyNet/TinyNet_A.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/TinyNet/TinyNet_A.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
inference_dir:null
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:False
-o Global.cpu_num_threads:1
-o Global.batch_size:1
-o Global.use_tensorrt:False
-o Global.use_fp16:False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val/ILSVRC2012_val_00000001.JPEG
-o Global.save_log_path:null
-o Global.benchmark:False
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,192,192]}]
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册