提交 3a8b5680 编写于 作者: H HydrogenSulfate

feat(model): add EfficientNetV2 code and fix AttrDict BUG

上级 7a0c7965
......@@ -33,7 +33,7 @@ class AttrDict(dict):
self[key] = value
def __deepcopy__(self, content):
return copy.deepcopy(dict(self))
return AttrDict(copy.deepcopy(dict(self)))
def create_attr_dict(yaml_config):
......
......@@ -38,6 +38,7 @@ from .model_zoo.dpn import DPN68, DPN92, DPN98, DPN107, DPN131
from .model_zoo.dsnet import DSNet_tiny, DSNet_small, DSNet_base
from .model_zoo.densenet import DenseNet121, DenseNet161, DenseNet169, DenseNet201, DenseNet264
from .model_zoo.efficientnet import EfficientNetB0, EfficientNetB1, EfficientNetB2, EfficientNetB3, EfficientNetB4, EfficientNetB5, EfficientNetB6, EfficientNetB7, EfficientNetB0_small
from .model_zoo.efficientnet_v2 import EfficientNetV2_S
from .model_zoo.resnest import ResNeSt50_fast_1s1x64d, ResNeSt50, ResNeSt101, ResNeSt200, ResNeSt269
from .model_zoo.googlenet import GoogLeNet
from .model_zoo.mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
......
此差异已折叠。
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 100
eval_during_train: True
eval_interval: 1
epochs: 350
print_batch_step: 20
use_visualdl: False
# used for static mode and model export
image_shape: [3, 384, 384]
save_inference_dir: ./inference
train_mode: efficientnetv2 # progressive training
AMP:
scale_loss: 65536
use_dynamic_loss_scaling: True
# O1: mixed fp16
level: O1
EMA:
decay: 0.9999
# model architecture
Arch:
name: EfficientNetV2_S
class_num: 1000
use_sync_bn: True
# loss function config for traing/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.65 # 8gpux128bs
warmup_epoch: 5
regularizer:
name: L2
coeff: 0.00001
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
scale: [0.05, 1.0]
size: 224
- RandFlipImage:
flip_code: 1
- RandAugmentV2:
num_layers: 2
magnitude: 5
- NormalizeImage:
scale: 1.0
mean: [128.0, 128.0, 128.0]
std: [128.0, 128.0, 128.0]
order: ""
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: True
shuffle: True
loader:
num_workers: 8
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- CropImageAtRatio:
size: 384
pad: 32
interpolation: bilinear
- NormalizeImage:
scale: 1.0
mean: [128.0, 128.0, 128.0]
std: [128.0, 128.0, 128.0]
order: ""
sampler:
name: DistributedBatchSampler
batch_size: 128
drop_last: False
shuffle: False
loader:
num_workers: 8
use_shared_memory: True
Infer:
infer_imgs: docs/images/inference_deployment/whl_demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- CropImageAtRatio:
size: 384
pad: 32
interpolation: bilinear
- NormalizeImage:
scale: 1.0
mean: [128.0, 128.0, 128.0]
std: [128.0, 128.0, 128.0]
order: ""
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
......@@ -15,6 +15,7 @@
from ppcls.data.preprocess.ops.autoaugment import ImageNetPolicy as RawImageNetPolicy
from ppcls.data.preprocess.ops.randaugment import RandAugment as RawRandAugment
from ppcls.data.preprocess.ops.randaugment import RandomApply
from ppcls.data.preprocess.ops.randaugment import RandAugmentV2 as RawRandAugmentV2
from ppcls.data.preprocess.ops.timm_autoaugment import RawTimmAutoAugment
from ppcls.data.preprocess.ops.cutout import Cutout
......@@ -25,6 +26,7 @@ from ppcls.data.preprocess.ops.grid import GridMask
from ppcls.data.preprocess.ops.operators import DecodeImage
from ppcls.data.preprocess.ops.operators import ResizeImage
from ppcls.data.preprocess.ops.operators import CropImage
from ppcls.data.preprocess.ops.operators import CropImageAtRatio
from ppcls.data.preprocess.ops.operators import CenterCrop, Resize
from ppcls.data.preprocess.ops.operators import RandCropImage
from ppcls.data.preprocess.ops.operators import RandCropImageV2
......@@ -101,6 +103,13 @@ class RandAugment(RawRandAugment):
return img
class RandAugmentV2(RawRandAugmentV2):
""" RandAugmentV2 wrapper to auto fit different img types """
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
class TimmAutoAugment(RawTimmAutoAugment):
""" TimmAutoAugment wrapper to auto fit different img tyeps. """
......
......@@ -319,6 +319,25 @@ class CropImage(object):
return img[h_start:h_end, w_start:w_end, :]
class CropImageAtRatio(object):
""" crop image with specified size and padding"""
def __init__(self, size: int, pad: int, interpolation="bilinear"):
self.size = size
self.ratio = size / (size + pad)
self.interpolation = interpolation
def __call__(self, img):
height, width = img.shape[:2]
crop_size = int(self.ratio * min(height, width))
y = (height - crop_size) // 2
x = (width - crop_size) // 2
crop_img = img[y:y + crop_size, x:x + crop_size, :]
return F.resize(crop_img, [self.size, self.size], self.interpolation)
class Padv2(object):
def __init__(self,
size=None,
......
......@@ -15,12 +15,60 @@
# This code is based on https://github.com/heartInsert/randaugment
# reference: https://arxiv.org/abs/1909.13719
from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random
from .operators import RawColorJitter
from paddle.vision.transforms import transforms as T
import numpy as np
from PIL import Image, ImageEnhance, ImageOps
def solarize_add(img, add, thresh=128, **__):
lut = []
for i in range(256):
if i < thresh:
lut.append(min(255, i + add))
else:
lut.append(i)
if img.mode in ("L", "RGB"):
if img.mode == "RGB" and len(lut) == 256:
lut = lut + lut + lut
return img.point(lut)
else:
return img
def cutout(image, pad_size, replace=0):
image_np = np.array(image)
image_height, image_width, _ = image_np.shape
# Sample the center location in the image where the zero mask will be applied.
cutout_center_height = np.random.randint(0, image_height + 1)
cutout_center_width = np.random.randint(0, image_width + 1)
lower_pad = np.maximum(0, cutout_center_height - pad_size)
upper_pad = np.maximum(0, image_height - cutout_center_height - pad_size)
left_pad = np.maximum(0, cutout_center_width - pad_size)
right_pad = np.maximum(0, image_width - cutout_center_width - pad_size)
cutout_shape = [
image_height - (lower_pad + upper_pad),
image_width - (left_pad + right_pad)
]
padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
mask = np.pad(np.zeros(
cutout_shape, dtype=image_np.dtype),
padding_dims,
constant_values=1)
mask = np.expand_dims(mask, -1)
mask = np.tile(mask, [1, 1, 3])
image_np = np.where(
np.equal(mask, 0),
np.full_like(
image_np, fill_value=replace, dtype=image_np.dtype),
image_np)
return Image.fromarray(image_np)
class RandAugment(object):
def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128)):
......@@ -95,10 +143,10 @@ class RandAugment(object):
"brightness": lambda img, magnitude:
ImageEnhance.Brightness(img).enhance(
1 + magnitude * rnd_ch_op([-1, 1])),
"autocontrast": lambda img, magnitude:
"autocontrast": lambda img, _:
ImageOps.autocontrast(img),
"equalize": lambda img, magnitude: ImageOps.equalize(img),
"invert": lambda img, magnitude: ImageOps.invert(img)
"equalize": lambda img, _: ImageOps.equalize(img),
"invert": lambda img, _: ImageOps.invert(img)
}
def __call__(self, img):
......@@ -121,4 +169,85 @@ class RandomApply(object):
def __call__(self, img):
timg = self.trans(img)
return timg
\ No newline at end of file
return timg
## RandAugment_EfficientNetV2 code below ##
class RandAugmentV2(RandAugment):
"""Customed RandAugment for EfficientNetV2"""
def __init__(self, num_layers=2, magnitude=5, fillcolor=(128, 128, 128)):
super().__init__(num_layers, magnitude, fillcolor)
abso_level = self.magnitude / self.max_level # [5.0~10.0/10.0]=[0.5, 1.0]
self.level_map = {
"shearX": 0.3 * abso_level,
"shearY": 0.3 * abso_level,
"translateX": 100.0 * abso_level,
"translateY": 100.0 * abso_level,
"rotate": 30 * abso_level,
"color": 1.8 * abso_level + 0.1,
"posterize": int(4.0 * abso_level),
"solarize": int(256.0 * abso_level),
"solarize_add": int(110.0 * abso_level),
"contrast": 1.8 * abso_level + 0.1,
"sharpness": 1.8 * abso_level + 0.1,
"brightness": 1.8 * abso_level + 0.1,
"autocontrast": 0,
"equalize": 0,
"invert": 0,
"cutout": int(40 * abso_level)
}
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(rot,
Image.new("RGBA", rot.size, (128, ) * 4),
rot).convert(img.mode)
rnd_ch_op = random.choice
self.func = {
"shearX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, magnitude * rnd_ch_op([-1, 1]), 0, 0, 1, 0),
Image.NEAREST,
fillcolor=fillcolor),
"shearY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, magnitude * rnd_ch_op([-1, 1]), 1, 0),
Image.NEAREST,
fillcolor=fillcolor),
"translateX": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, magnitude * rnd_ch_op([-1, 1]), 0, 1, 0),
Image.NEAREST,
fillcolor=fillcolor),
"translateY": lambda img, magnitude: img.transform(
img.size,
Image.AFFINE,
(1, 0, 0, 0, 1, magnitude * rnd_ch_op([-1, 1])),
Image.NEAREST,
fillcolor=fillcolor),
"rotate": lambda img, magnitude: rotate_with_fill(img, magnitude * rnd_ch_op([-1, 1])),
"color": lambda img, magnitude: ImageEnhance.Color(img).enhance(magnitude),
"posterize": lambda img, magnitude:
ImageOps.posterize(img, magnitude),
"solarize": lambda img, magnitude:
ImageOps.solarize(img, magnitude),
"solarize_add": lambda img, magnitude:
solarize_add(img, magnitude),
"contrast": lambda img, magnitude:
ImageEnhance.Contrast(img).enhance(magnitude),
"sharpness": lambda img, magnitude:
ImageEnhance.Sharpness(img).enhance(magnitude),
"brightness": lambda img, magnitude:
ImageEnhance.Brightness(img).enhance(magnitude),
"autocontrast": lambda img, _:
ImageOps.autocontrast(img),
"equalize": lambda img, _: ImageOps.equalize(img),
"invert": lambda img, _: ImageOps.invert(img),
"cutout": lambda img, magnitude: cutout(img, magnitude, replace=fillcolor[0])
}
......@@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ppcls.engine.train.train import train_epoch
from ppcls.engine.train.train_efficientnetv2 import train_epoch_efficientnetv2
from ppcls.engine.train.train_fixmatch import train_epoch_fixmatch
from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
\ No newline at end of file
from ppcls.engine.train.train_fixmatch_ccssl import train_epoch_fixmatch_ccssl
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function
import time
import numpy as np
from ppcls.data import build_dataloader
from ppcls.utils import logger
from .train import train_epoch
def train_epoch_efficientnetv2(engine, epoch_id, print_batch_step):
# 1. Build training hyper-parameters for different training stage
num_stage = 4
ratio_list = [(i + 1) / num_stage for i in range(num_stage)]
ram_list = np.linspace(5, 10, num_stage)
# dropout_rate_list = np.linspace(0.0, 0.2, num_stage)
stones = [
int(engine.config["Global"]["epochs"] * ratio_list[i])
for i in range(num_stage)
]
image_size_list = [
int(128 + (300 - 128) * ratio_list[i]) for i in range(num_stage)
]
stage_id = 0
for i in range(num_stage):
if epoch_id > stones[i]:
stage_id = i + 1
# 2. Adjust training hyper-parameters for different training stage
if not hasattr(engine, 'last_stage') or engine.last_stage < stage_id:
engine.config["DataLoader"]["Train"]["dataset"]["transform_ops"][1][
"RandCropImage"]["size"] = image_size_list[stage_id]
engine.config["DataLoader"]["Train"]["dataset"]["transform_ops"][3][
"RandAugment"]["magnitude"] = ram_list[stage_id]
engine.train_dataloader = build_dataloader(
engine.config["DataLoader"],
"Train",
engine.device,
engine.use_dali,
seed=epoch_id)
engine.train_dataloader_iter = iter(engine.train_dataloader)
engine.last_stage = stage_id
logger.info(
f"Training stage: [{stage_id+1}/{num_stage}](random_aug_magnitude={ram_list[stage_id]}, train_image_size={image_size_list[stage_id]})"
)
# 3. Train one epoch as usual at current stage
train_epoch(engine, epoch_id, print_batch_step)
......@@ -33,7 +33,7 @@ class AttrDict(dict):
self[key] = value
def __deepcopy__(self, content):
return copy.deepcopy(dict(self))
return AttrDict(copy.deepcopy(dict(self)))
def create_attr_dict(yaml_config):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册