未验证 提交 9d9cd371 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

add static training (#1037)

* add static training

* fix typo

* add se fp16

* rm note

* fix loader

* fix cfg
上级 73004f78
......@@ -104,7 +104,8 @@ class ConvBNLayer(TheseusLayer):
groups=1,
is_vd_mode=False,
act=None,
lr_mult=1.0):
lr_mult=1.0,
data_format="NCHW"):
super().__init__()
self.is_vd_mode = is_vd_mode
self.act = act
......@@ -118,11 +119,13 @@ class ConvBNLayer(TheseusLayer):
padding=(filter_size - 1) // 2,
groups=groups,
weight_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=False)
bias_attr=False,
data_format=data_format)
self.bn = BatchNorm(
num_filters,
param_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=ParamAttr(learning_rate=lr_mult))
bias_attr=ParamAttr(learning_rate=lr_mult),
data_layout=data_format)
self.relu = nn.ReLU()
def forward(self, x):
......@@ -136,14 +139,14 @@ class ConvBNLayer(TheseusLayer):
class BottleneckBlock(TheseusLayer):
def __init__(
self,
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
if_first=False,
lr_mult=1.0, ):
lr_mult=1.0,
data_format="NCHW"):
super().__init__()
self.conv0 = ConvBNLayer(
......@@ -151,20 +154,23 @@ class BottleneckBlock(TheseusLayer):
num_filters=num_filters,
filter_size=1,
act="relu",
lr_mult=lr_mult)
lr_mult=lr_mult,
data_format=data_format)
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
stride=stride,
act="relu",
lr_mult=lr_mult)
lr_mult=lr_mult,
data_format=data_format)
self.conv2 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters * 4,
filter_size=1,
act=None,
lr_mult=lr_mult)
lr_mult=lr_mult,
data_format=data_format)
if not shortcut:
self.short = ConvBNLayer(
......@@ -173,7 +179,8 @@ class BottleneckBlock(TheseusLayer):
filter_size=1,
stride=stride if if_first else 1,
is_vd_mode=False if if_first else True,
lr_mult=lr_mult)
lr_mult=lr_mult,
data_format=data_format)
self.relu = nn.ReLU()
self.shortcut = shortcut
......@@ -199,7 +206,8 @@ class BasicBlock(TheseusLayer):
stride,
shortcut=True,
if_first=False,
lr_mult=1.0):
lr_mult=1.0,
data_format="NCHW"):
super().__init__()
self.stride = stride
......@@ -209,13 +217,15 @@ class BasicBlock(TheseusLayer):
filter_size=3,
stride=stride,
act="relu",
lr_mult=lr_mult)
lr_mult=lr_mult,
data_format=data_format)
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
act=None,
lr_mult=lr_mult)
lr_mult=lr_mult,
data_format=data_format)
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
......@@ -223,7 +233,8 @@ class BasicBlock(TheseusLayer):
filter_size=1,
stride=stride if if_first else 1,
is_vd_mode=False if if_first else True,
lr_mult=lr_mult)
lr_mult=lr_mult,
data_format=data_format)
self.shortcut = shortcut
self.relu = nn.ReLU()
......@@ -256,7 +267,9 @@ class ResNet(TheseusLayer):
config,
version="vb",
class_num=1000,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
data_format="NCHW",
input_image_channel=3):
super().__init__()
self.cfg = config
......@@ -279,22 +292,25 @@ class ResNet(TheseusLayer):
self.stem_cfg = {
#num_channels, num_filters, filter_size, stride
"vb": [[3, 64, 7, 2]],
"vd": [[3, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
"vb": [[input_image_channel, 64, 7, 2]],
"vd":
[[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
}
self.stem = nn.Sequential(*[
self.stem = nn.Sequential(* [
ConvBNLayer(
num_channels=in_c,
num_filters=out_c,
filter_size=k,
stride=s,
act="relu",
lr_mult=self.lr_mult_list[0])
lr_mult=self.lr_mult_list[0],
data_format=data_format)
for in_c, out_c, k, s in self.stem_cfg[version]
])
self.max_pool = MaxPool2D(kernel_size=3, stride=2, padding=1)
self.max_pool = MaxPool2D(
kernel_size=3, stride=2, padding=1, data_format=data_format)
block_list = []
for block_idx in range(len(self.block_depth)):
shortcut = False
......@@ -306,11 +322,12 @@ class ResNet(TheseusLayer):
stride=2 if i == 0 and block_idx != 0 else 1,
shortcut=shortcut,
if_first=block_idx == i == 0 if version == "vd" else True,
lr_mult=self.lr_mult_list[block_idx + 1]))
lr_mult=self.lr_mult_list[block_idx + 1],
data_format=data_format))
shortcut = True
self.blocks = nn.Sequential(*block_list)
self.avg_pool = AdaptiveAvgPool2D(1)
self.avg_pool = AdaptiveAvgPool2D(1, data_format=data_format)
self.flatten = nn.Flatten()
self.avg_pool_channels = self.num_channels[-1] * 2
stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0)
......@@ -319,7 +336,13 @@ class ResNet(TheseusLayer):
self.class_num,
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
self.data_format = data_format
def forward(self, x):
with paddle.static.amp.fp16_guard():
if self.data_format == "NHWC":
x = paddle.transpose(x, [0, 2, 3, 1])
x.stop_gradient = True
x = self.stem(x)
x = self.max_pool(x)
x = self.blocks(x)
......
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 120
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_channel: &image_channel 4
image_shape: [*image_channel, 224, 224]
save_inference_dir: ./inference
# training model under @to_static
to_static: False
# mixed precision training
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
use_pure_fp16: &use_pure_fp16 True
# model architecture
Arch:
name: ResNet50
class_num: 1000
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
Eval:
- CELoss:
weight: 1.0
Optimizer:
name: Momentum
momentum: 0.9
multi_precision: False # *use_pure_fp16
lr:
name: Piecewise
learning_rate: 0.1
decay_epochs: [30, 60, 90]
values: [0.1, 0.01, 0.001, 0.0001]
regularizer:
name: 'L2'
coeff: 0.0001
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
output_fp16: *use_pure_fp16
channel_num: *image_channel
sampler:
name: DistributedBatchSampler
batch_size: 32
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
output_fp16: *use_pure_fp16
channel_num: *image_channel
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/whl/demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
output_fp16: *use_pure_fp16
channel_num: *image_channel
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
# global configs
Global:
checkpoints: null
pretrained_model: null
output_dir: ./output/
device: gpu
save_interval: 1
eval_during_train: True
eval_interval: 1
epochs: 200
print_batch_step: 10
use_visualdl: False
# used for static mode and model export
image_channel: &image_channel 4
image_shape: [*image_channel, 224, 224]
save_inference_dir: ./inference
# model architecture
Arch:
name: SE_ResNeXt101_32x4d
class_num: 1000
# loss function config for training/eval process
Loss:
Train:
- CELoss:
weight: 1.0
epsilon: 0.1
Eval:
- CELoss:
weight: 1.0
# mixed precision training
AMP:
scale_loss: 128.0
use_dynamic_loss_scaling: True
use_pure_fp16: &use_pure_fp16 True
Optimizer:
name: Momentum
momentum: 0.9
lr:
name: Cosine
learning_rate: 0.1
regularizer:
name: 'L2'
coeff: 0.00007
# data loader for train and eval
DataLoader:
Train:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/train_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- RandCropImage:
size: 224
- RandFlipImage:
flip_code: 1
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
output_fp16: *use_pure_fp16
channel_num: *image_channel
sampler:
name: DistributedBatchSampler
batch_size: 64
drop_last: False
shuffle: True
loader:
num_workers: 4
use_shared_memory: True
Eval:
dataset:
name: ImageNetDataset
image_root: ./dataset/ILSVRC2012/
cls_label_path: ./dataset/ILSVRC2012/val_list.txt
transform_ops:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
output_fp16: *use_pure_fp16
channel_num: *image_channel
sampler:
name: BatchSampler
batch_size: 64
drop_last: False
shuffle: False
loader:
num_workers: 4
use_shared_memory: True
Infer:
infer_imgs: docs/images/whl/demo.jpg
batch_size: 10
transforms:
- DecodeImage:
to_rgb: True
channel_first: False
- ResizeImage:
resize_short: 256
- CropImage:
size: 224
- NormalizeImage:
scale: 1.0/255.0
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: ''
output_fp16: *use_pure_fp16
channel_num: *image_channel
- ToCHWImage:
PostProcess:
name: Topk
topk: 5
class_id_map_file: ppcls/utils/imagenet1k_label_list.txt
Metric:
Train:
- TopkAcc:
topk: [1, 5]
Eval:
- TopkAcc:
topk: [1, 5]
\ No newline at end of file
......@@ -60,6 +60,7 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
if use_dali:
from ppcls.data.dataloader.dali import dali_dataloader
return dali_dataloader(config, mode, paddle.device.get_device(), seed)
config_dataset = config[mode]['dataset']
config_dataset = copy.deepcopy(config_dataset)
dataset_name = config_dataset.pop('name')
......@@ -74,10 +75,6 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None):
# build sampler
config_sampler = config[mode]['sampler']
#config_sampler["batch_size"] = config_sampler[
# "batch_size"] // paddle.distributed.get_world_size()
#assert config_sampler[
# "batch_size"] >= 1, "The batch_size should be larger than gpu number."
if "name" not in config_sampler:
batch_sampler = None
batch_size = config_sampler["batch_size"]
......
......@@ -148,7 +148,6 @@ def dali_dataloader(config, mode, device, seed=None):
assert "gpu" in device, "gpu training is required for DALI"
device_id = int(device.split(':')[1])
config_dataloader = config[mode]
# mode = 'train' if mode.lower() == 'train' else 'eval'
seed = 42 if seed is None else seed
ops = [
list(x.keys())[0]
......@@ -160,6 +159,7 @@ def dali_dataloader(config, mode, device, seed=None):
support_ops_eval = [
"DecodeImage", "ResizeImage", "CropImage", "NormalizeImage"
]
if mode.lower() == 'train':
assert set(ops) == set(
support_ops_train
......@@ -171,6 +171,14 @@ def dali_dataloader(config, mode, device, seed=None):
), "The supported trasform_ops for eval_dataset in dali is : {}".format(
",".join(support_ops_eval))
normalize_ops = [
op for op in config_dataloader["dataset"]["transform_ops"]
if "NormalizeImage" in op
][0]["NormalizeImage"]
channel_num = normalize_ops.get("channel_num", 3)
output_dtype = types.FLOAT16 if normalize_ops.get("output_fp16",
False) else types.FLOAT
env = os.environ
# assert float(env.get('FLAGS_fraction_of_gpu_memory_to_use', 0.92)) < 0.9, \
# "Please leave enough GPU memory for DALI workspace, e.g., by setting" \
......@@ -179,9 +187,6 @@ def dali_dataloader(config, mode, device, seed=None):
gpu_num = paddle.distributed.get_world_size()
batch_size = config_dataloader["sampler"]["batch_size"]
# assert batch_size % gpu_num == 0, \
# "batch size must be multiple of number of devices"
# batch_size = batch_size // gpu_num
file_root = config_dataloader["dataset"]["image_root"]
file_list = config_dataloader["dataset"]["cls_label_path"]
......@@ -195,15 +200,9 @@ def dali_dataloader(config, mode, device, seed=None):
INTERP_LANCZOS3, # XXX use LANCZOS3 for cv2.INTER_LANCZOS4
}
output_dtype = (types.FLOAT16 if 'AMP' in config and
config.AMP.get("use_pure_fp16", False) else types.FLOAT)
assert interp in interp_map, "interpolation method not supported by DALI"
interp = interp_map[interp]
pad_output = False
image_shape = config.get("image_shape", None)
if image_shape and image_shape[0] == 4:
pad_output = True
pad_output = channel_num == 4
transforms = {
k: v
......@@ -218,6 +217,10 @@ def dali_dataloader(config, mode, device, seed=None):
mean = [v / scale for v in mean]
std = [v / scale for v in std]
sampler_name = config_dataloader["sampler"].get("name",
"DistributedBatchSampler")
assert sampler_name in ["DistributedBatchSampler", "BatchSampler"]
if mode.lower() == "train":
resize_shorter = 256
crop = transforms["RandCropImage"]["size"]
......@@ -279,10 +282,11 @@ def dali_dataloader(config, mode, device, seed=None):
else:
resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
crop = transforms["CropImage"]["size"]
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env:
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env and sampler_name == "DistributedBatchSampler":
shard_id = int(env['PADDLE_TRAINER_ID'])
num_shards = int(env['PADDLE_TRAINERS_NUM'])
device_id = int(env['FLAGS_selected_gpus'])
pipe = HybridValPipe(
file_root,
file_list,
......
......@@ -197,14 +197,26 @@ class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
def __init__(self, scale=None, mean=None, std=None, order='chw'):
def __init__(self,
scale=None,
mean=None,
std=None,
order='chw',
output_fp16=False,
channel_num=3):
if isinstance(scale, str):
scale = eval(scale)
assert channel_num in [
3, 4
], "channel number of input image should be set to 3 or 4."
self.channel_num = channel_num
self.output_dtype = 'float16' if output_fp16 else 'float32'
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
self.order = order
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
......@@ -215,7 +227,20 @@ class NormalizeImage(object):
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
return (img.astype('float32') * self.scale - self.mean) / self.std
img = (img.astype('float32') * self.scale - self.mean) / self.std
if self.channel_num == 4:
img_h = img.shape[1] if self.order == 'chw' else img.shape[0]
img_w = img.shape[2] if self.order == 'chw' else img.shape[1]
pad_zeros = np.zeros(
(1, img_h, img_w)) if self.order == 'chw' else np.zeros(
(img_h, img_w, 1))
img = (np.concatenate(
(img, pad_zeros), axis=0)
if self.order == 'chw' else np.concatenate(
(img, pad_zeros), axis=2))
return img.astype(self.output_dtype)
class ToCHWImage(object):
......
......@@ -41,7 +41,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch):
return lr
def build_optimizer(config, epochs, step_each_epoch, parameters):
def build_optimizer(config, epochs, step_each_epoch, parameters=None):
config = copy.deepcopy(config)
# step1 build lr
lr = build_lr_scheduler(config.pop('lr'), epochs, step_each_epoch)
......
......@@ -33,12 +33,14 @@ class Momentum(object):
learning_rate,
momentum,
weight_decay=None,
grad_clip=None):
grad_clip=None,
multi_precision=False):
super(Momentum, self).__init__()
self.learning_rate = learning_rate
self.momentum = momentum
self.weight_decay = weight_decay
self.grad_clip = grad_clip
self.multi_precision = multi_precision
def __call__(self, parameters):
opt = optim.Momentum(
......@@ -46,6 +48,7 @@ class Momentum(object):
momentum=self.momentum,
weight_decay=self.weight_decay,
grad_clip=self.grad_clip,
multi_precision=self.multi_precision,
parameters=parameters)
return opt
......@@ -60,7 +63,8 @@ class Adam(object):
weight_decay=None,
grad_clip=None,
name=None,
lazy_mode=False):
lazy_mode=False,
multi_precision=False):
self.learning_rate = learning_rate
self.beta1 = beta1
self.beta2 = beta2
......@@ -71,6 +75,7 @@ class Adam(object):
self.grad_clip = grad_clip
self.name = name
self.lazy_mode = lazy_mode
self.multi_precision = multi_precision
def __call__(self, parameters):
opt = optim.Adam(
......@@ -82,6 +87,7 @@ class Adam(object):
grad_clip=self.grad_clip,
name=self.name,
lazy_mode=self.lazy_mode,
multi_precision=self.multi_precision,
parameters=parameters)
return opt
......@@ -104,7 +110,8 @@ class RMSProp(object):
rho=0.95,
epsilon=1e-6,
weight_decay=None,
grad_clip=None):
grad_clip=None,
multi_precision=False):
super(RMSProp, self).__init__()
self.learning_rate = learning_rate
self.momentum = momentum
......@@ -112,6 +119,7 @@ class RMSProp(object):
self.epsilon = epsilon
self.weight_decay = weight_decay
self.grad_clip = grad_clip
self.multi_precision = multi_precision
def __call__(self, parameters):
opt = optim.RMSProp(
......@@ -121,5 +129,6 @@ class RMSProp(object):
epsilon=self.epsilon,
weight_decay=self.weight_decay,
grad_clip=self.grad_clip,
multi_precision=self.multi_precision,
parameters=parameters)
return opt
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import numpy as np
from collections import OrderedDict
import paddle
import paddle.nn.functional as F
from paddle.distributed import fleet
from paddle.distributed.fleet import DistributedStrategy
# from ppcls.optimizer import OptimizerBuilder
# from ppcls.optimizer.learning_rate import LearningRateBuilder
from ppcls.arch import build_model
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.optimizer import build_lr_scheduler
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
def create_feeds(image_shape, use_mix=None, dtype="float32"):
    """
    Declare the static-graph input variables consumed by the model.

    Args:
        image_shape(list[int]): per-sample input shape, such as [3, 224, 224];
            a leading ``None`` (batch) dimension is prepended to it.
        use_mix(bool): whether to use mix(include mixup, cutmix, fmix);
            selects which label variables are declared.
        dtype(str): dtype of the "data" variable (and of "lam" when
            use_mix is truthy).

    Returns:
        feeds(OrderedDict): variable name -> input variable, in
            declaration order
    """

    def _declare(name, shape, var_dtype):
        return paddle.static.data(name=name, shape=shape, dtype=var_dtype)

    feeds = OrderedDict()
    feeds['data'] = _declare("data", [None] + image_shape, dtype)

    if not use_mix:
        # plain training/eval: a single integer label per sample
        feeds['label'] = _declare("label", [None, 1], "int64")
        return feeds

    # mix training: two label variables plus the mixing coefficient
    feeds['y_a'] = _declare("y_a", [None, 1], "int64")
    feeds['y_b'] = _declare("y_b", [None, 1], "int64")
    feeds['lam'] = _declare("lam", [None, 1], dtype)
    return feeds
def create_fetchs(out,
                  feeds,
                  architecture,
                  topk=5,
                  epsilon=None,
                  use_mix=False,
                  config=None,
                  mode="Train"):
    """
    Create fetchs as model outputs (loss and metrics).

    Args:
        out(variable): model output variable
        feeds(dict): dict of model input variables; must contain 'label'.
        architecture(dict): architecture information (currently unused,
            kept for interface compatibility)
        topk(int): currently unused, kept for interface compatibility
        epsilon(float): currently unused, kept for interface compatibility
        use_mix(bool): whether to use mix(include mixup, cutmix, fmix).
            Mix training is not supported yet, so any truthy value raises.
        config(dict): full config; config["Loss"][mode] and
            config["Metric"][mode] are read here
        mode(str): "Train" or "Eval"; selects the loss/metric sub-config and
            whether metrics are all-reduced across trainers

    Returns:
        fetchs(OrderedDict): name -> (variable, AverageMeter), with 'loss'
            first followed by every metric returned by the metric function

    Raises:
        NotImplementedError: if use_mix is truthy.
    """
    # TODO(littletomatodonkey): support mix training
    # Fail before touching the graph: the mix branch never produced a usable
    # target, so it previously died later with an unrelated NameError.
    if use_mix:
        raise NotImplementedError(
            "mix training (mixup, cutmix, fmix) is not supported yet")

    fetchs = OrderedDict()
    target = paddle.reshape(feeds['label'], [-1, 1])

    # build loss
    loss_func = build_loss(config["Loss"][mode])
    loss_dict = loss_func(out, target)
    loss_out = loss_dict["loss"]
    fetchs['loss'] = (loss_out, AverageMeter('loss', '7.4f', need_avg=True))

    # build metric
    metric_func = build_metrics(config["Metric"][mode])
    metric_dict = metric_func(out, target)
    world_size = paddle.distributed.get_world_size()
    for key in metric_dict:
        if mode != "Train" and world_size > 1:
            # outside training, average each metric over all trainers
            paddle.distributed.all_reduce(
                metric_dict[key], op=paddle.distributed.ReduceOp.SUM)
            metric_dict[key] = metric_dict[key] / world_size

        fetchs[key] = (metric_dict[key], AverageMeter(
            key, '7.4f', need_avg=True))

    return fetchs
def create_optimizer(config, step_each_epoch):
    """
    Build the optimizer and its learning-rate scheduler from the config.

    Args:
        config(dict): full config; config["Optimizer"] and
            config["Global"]["epochs"] are read
        step_each_epoch(int): forwarded unchanged to build_optimizer

    Returns:
        (optimizer, lr_scheduler): the pair returned by build_optimizer
    """
    optimizer_cfg = config["Optimizer"]
    total_epochs = config["Global"]["epochs"]
    opt, scheduler = build_optimizer(optimizer_cfg, total_epochs,
                                     step_each_epoch)
    return opt, scheduler
def create_strategy(config):
    """
    Create the build strategy and the execution strategy.

    The presence of an 'AMP' section in the config is the default for the
    four fuse/addto options; each can still be overridden by a top-level
    config key of the same name.

    Args:
        config(dict): config

    Returns:
        build_strategy: build strategy
        exec_strategy: exec strategy
    """
    build_strategy = paddle.static.BuildStrategy()
    exec_strategy = paddle.static.ExecutionStrategy()

    amp_enabled = 'AMP' in config
    pure_fp16 = amp_enabled and config.AMP.get("use_pure_fp16", False)

    exec_strategy.num_threads = 1
    exec_strategy.num_iteration_per_drop_scope = 10000 if pure_fp16 else 10

    fusion_options = (
        'fuse_bn_act_ops',
        'fuse_elewise_add_act_ops',
        'fuse_bn_add_act_ops',
        'enable_addto', )
    # resolve every option first, then apply them to the strategy
    resolved = [(name, config.get(name, amp_enabled))
                for name in fusion_options]
    for name, value in resolved:
        setattr(build_strategy, name, value)

    return build_strategy, exec_strategy
def dist_optimizer(config, optimizer):
    """
    Wrap a normal optimizer into a fleet distributed optimizer.

    Args:
        config(dict): config, forwarded to create_strategy
        optimizer(): a normal optimizer

    Returns:
        optimizer: the result of fleet.distributed_optimizer
    """
    build_cfg, exec_cfg = create_strategy(config)

    strategy = DistributedStrategy()
    strategy.execution_strategy = exec_cfg
    strategy.build_strategy = build_cfg
    # fixed values chosen by this helper (not read from the config)
    strategy.nccl_comm_num = 1
    strategy.fuse_all_reduce_ops = True
    strategy.fuse_grad_size_in_MB = 16

    return fleet.distributed_optimizer(optimizer, strategy=strategy)
def mixed_precision_optimizer(config, optimizer):
    """
    Decorate the optimizer for mixed precision when the config asks for it.

    Args:
        config(dict): config; only the optional 'AMP' section is read
        optimizer(): the optimizer to decorate

    Returns:
        optimizer: the input optimizer untouched when 'AMP' is absent,
            otherwise the result of paddle.static.amp.decorate
    """
    if 'AMP' not in config:
        return optimizer

    # an 'AMP' key with an empty/None value falls back to all defaults
    amp_cfg = config.AMP or dict()
    loss_scale = amp_cfg.get('scale_loss', 1.0)
    dynamic_scaling = amp_cfg.get('use_dynamic_loss_scaling', False)
    pure_fp16 = amp_cfg.get('use_pure_fp16', False)

    return paddle.static.amp.decorate(
        optimizer,
        init_loss_scaling=loss_scale,
        use_dynamic_loss_scaling=dynamic_scaling,
        use_pure_fp16=pure_fp16,
        use_fp16_guard=True)
def build(config,
          main_prog,
          startup_prog,
          step_each_epoch=100,
          is_train=True,
          is_distributed=True):
    """
    Build a program using a model and (for training) an optimizer
        1. create feeds
        2. create a model
        3. create fetchs
        4. create an optimizer (training only)

    Note: when the first entry of config["Global"]["image_shape"] is not 3,
    config["Arch"]["input_image_channel"] is set in place before the model
    is built.

    Args:
        config(dict): config
        main_prog(): main program
        startup_prog(): startup program
        step_each_epoch(int): forwarded to build_optimizer
        is_train(bool): train or eval
        is_distributed(bool): whether to use distributed training method

    Returns:
        fetchs(dict): dict of model outputs(included loss and measures)
        lr_scheduler(): learning-rate scheduler, None when not training
        feeds(dict): dict of model input variables
        optimizer(): the (possibly decorated) optimizer, None when not
            training
    """
    with paddle.static.program_guard(main_prog, startup_prog):
        with paddle.utils.unique_name.guard():
            mode = "Train" if is_train else "Eval"
            # mix training is signalled by batch-level transforms
            use_mix = "batch_transform_ops" in config["DataLoader"][mode][
                "dataset"]
            feeds = create_feeds(
                config["Global"]["image_shape"],
                use_mix=use_mix,
                dtype="float32")

            # build model
            # data_format should be assigned in arch-dict
            input_image_channel = config["Global"]["image_shape"][
                0]  # default as [3, 224, 224]
            if input_image_channel != 3:
                logger.warning(
                    "Input image channel is changed to {}, maybe for better speed-up".
                    format(input_image_channel))
                config["Arch"]["input_image_channel"] = input_image_channel
            model = build_model(config["Arch"])
            out = model(feeds["data"])
            # end of build model

            fetchs = create_fetchs(
                out,
                feeds,
                config["Arch"],
                epsilon=config.get('ls_epsilon'),
                use_mix=use_mix,
                config=config,
                mode=mode)

            lr_scheduler = None
            optimizer = None
            if is_train:
                optimizer, lr_scheduler = build_optimizer(
                    config["Optimizer"], config["Global"]["epochs"],
                    step_each_epoch)
                # wrap for AMP first, then for distributed training
                optimizer = mixed_precision_optimizer(config, optimizer)
                if is_distributed:
                    optimizer = dist_optimizer(config, optimizer)
                optimizer.minimize(fetchs['loss'][0])
    return fetchs, lr_scheduler, feeds, optimizer
def compile(config, program, loss_name=None, share_prog=None):
    """
    Compile the program for data-parallel execution.

    Args:
        config(dict): config, forwarded to create_strategy
        program(): the program to wrap in a CompiledProgram
        loss_name(str): loss name
        share_prog(): the shared program, used for evaluation during training

    Returns:
        compiled_program(): a compiled program
    """
    build_strategy, exec_strategy = create_strategy(config)

    parallel_kwargs = dict(
        share_vars_from=share_prog,
        loss_name=loss_name,
        build_strategy=build_strategy,
        exec_strategy=exec_strategy)
    wrapped = paddle.static.CompiledProgram(program)
    return wrapped.with_data_parallel(**parallel_kwargs)
total_step = 0
def run(dataloader,
        exe,
        program,
        feeds,
        fetchs,
        epoch=0,
        mode='train',
        config=None,
        vdl_writer=None,
        lr_scheduler=None):
    """
    Feed data to the model and fetch the measures and loss

    Args:
        dataloader(paddle io dataloader): called once to obtain the batch
            iterator unless config["Global"]["use_dali"] is set, in which
            case it is iterated directly and reset at the end
        exe(): executor used to run the program
        program(): program to run
        feeds(dict): dict of model input variables; their order maps the
            positional batch items to feed names (non-DALI path)
        fetchs(dict): dict of measures and the loss
        epoch(int): epoch of training or evaluation
        mode(str): 'train' or 'eval'; controls logging, lr recording and
            the return value
        config(dict): config
        vdl_writer(): optional writer; when set, the loss is logged each
            step
        lr_scheduler(): optional scheduler; stepped after every batch

    Returns:
        The running average of the "top1" metric when mode is 'eval',
        otherwise None.
    """
    global total_step

    fetch_list = [f[0] for f in fetchs.values()]
    metric_dict = OrderedDict([("lr", AverageMeter(
        'lr', 'f', postfix=",", need_avg=False))])

    for k in fetchs:
        metric_dict[k] = fetchs[k][1]

    metric_dict["batch_time"] = AverageMeter(
        'batch_cost', '.5f', postfix=" s,")
    metric_dict["reader_time"] = AverageMeter(
        'reader_cost', '.5f', postfix=" s,")

    for m in metric_dict.values():
        m.reset()

    use_dali = config["Global"].get('use_dali', False)
    tic = time.time()

    if not use_dali:
        dataloader = dataloader()

    idx = 0
    batch_size = None
    while True:
        # The DALI maybe raise RuntimeError for some particular images, such as ImageNet1k/n04418357_26036.JPEG
        try:
            batch = next(dataloader)
        except StopIteration:
            break
        except RuntimeError:
            logger.warning(
                "Except RuntimeError when reading data from dataloader, try to read once again..."
            )
            continue
        idx += 1
        # ignore the warmup iters
        if idx == 5:
            metric_dict["batch_time"].reset()
            metric_dict["reader_time"].reset()

        metric_dict['reader_time'].update(time.time() - tic)

        # NOTE: ``shape`` is called as a method on the batch tensors here
        if use_dali:
            batch_size = batch[0]["data"].shape()[0]
            feed_dict = batch[0]
        else:
            batch_size = batch[0].shape()[0]
            feed_dict = {
                key.name: batch[pos]
                for pos, key in enumerate(feeds.values())
            }

        metrics = exe.run(program=program,
                          feed=feed_dict,
                          fetch_list=fetch_list)

        for name, m in zip(fetchs.keys(), metrics):
            metric_dict[name].update(np.mean(m), batch_size)
        metric_dict["batch_time"].update(time.time() - tic)
        # only record the lr when a scheduler was actually supplied
        if mode == "train" and lr_scheduler is not None:
            metric_dict['lr'].update(lr_scheduler.get_lr())

        fetchs_str = ' '.join([
            str(metric_dict[key].mean)
            if "time" in key else str(metric_dict[key].value)
            for key in metric_dict
        ])
        ips_info = " ips: {:.5f} images/sec.".format(
            batch_size / metric_dict["batch_time"].avg)
        fetchs_str += ips_info

        if lr_scheduler is not None:
            lr_scheduler.step()

        if vdl_writer:
            logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
            total_step += 1

        if mode == 'eval':
            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
                                                           fetchs_str))
        else:
            epoch_str = "epoch:{:<3d}".format(epoch)
            step_str = "{:s} step:{:<4d}".format(mode, idx)

            if idx % config.get('print_interval', 10) == 0:
                logger.info("{:s} {:s} {:s}".format(epoch_str, step_str,
                                                    fetchs_str))

        tic = time.time()

    end_str = ' '.join([str(m.mean) for m in metric_dict.values()] +
                       [metric_dict["batch_time"].total])

    # An empty dataloader leaves batch_size unset and the timer at zero;
    # report 0 ips instead of crashing on None * int or a division by zero.
    batch_timer = metric_dict["batch_time"]
    if batch_size is None or not batch_timer.sum:
        overall_ips = 0.0
    else:
        overall_ips = batch_size * batch_timer.count / batch_timer.sum
    ips_info = "ips: {:.5f} images/sec.".format(overall_ips)

    if mode == 'eval':
        logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
    else:
        end_epoch_str = "END epoch:{:<3d}".format(epoch)
        logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
                                                 ips_info))
    if use_dali:
        dataloader.reset()

    # return top1_acc in order to save the best model
    if mode == 'eval':
        return fetchs["top1"][1].avg
#!/usr/bin/env bash
# Launch ppcls/static/train.py on GPUs 0-3 through paddle.distributed.launch,
# using the ResNet50 fp16 config and overriding Global.use_dali to True.
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# Cap the GPU memory fraction claimed up front; presumably this leaves room
# for the DALI workspace — confirm against the DALI loader's requirements.
export FLAGS_fraction_of_gpu_memory_to_use=0.80
python3.7 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
ppcls/static//train.py \
-c ./ppcls/configs/ImageNet/ResNet/ResNet50_fp16.yaml \
-o Global.use_dali=True
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import errno
import os
import re
import shutil
import tempfile
import paddle
from ppcls.utils import logger
__all__ = ['init_model', 'save_model']
def _mkdir_if_not_exist(path):
"""
mkdir if not exists, ignore the exception when multiprocess mkdir together
"""
if not os.path.exists(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno == errno.EEXIST and os.path.isdir(path):
logger.warning(
'be happy if some process has already created {}'.format(
path))
else:
raise OSError('Failed to mkdir {}'.format(path))
def _load_state(path):
    """
    Load the program state stored under the given path prefix.

    When a sibling '<path>.pdopt' file exists, only '<path>.pdparams' is
    copied into a temporary directory and loaded from there, so that the
    optimizer state is ignored.

    Args:
        path (string): path prefix of the saved model (without extension).

    Returns:
        The state returned by paddle.static.load_program_state.
    """
    if not os.path.exists(path + '.pdopt'):
        return paddle.static.load_program_state(path)

    # XXX another hack to ignore the optimizer state
    tmp = tempfile.mkdtemp()
    try:
        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
        shutil.copy(path + '.pdparams', dst + '.pdparams')
        return paddle.static.load_program_state(dst)
    finally:
        # always remove the temp dir, even if the copy or the load fails
        shutil.rmtree(tmp)
def load_params(exe, prog, path, ignore_params=None):
    """
    Load pretrained weights found at ``path`` into ``prog``.

    Entries of the loaded state are dropped before it is applied when
    (a) a parameter of ``prog`` has a different shape than the stored value,
    or (b) a variable name of ``prog`` matches one of ``ignore_params``.

    Args:
        exe: executor object; not used by this function.
        prog: program whose variables receive the loaded state.
        path (string): directory or local model path prefix
            (``path + '.pdparams'`` must exist when it is not a directory).
        ignore_params (list): regular expressions (applied with ``re.match``)
            naming variables that must not be loaded, e.g. when finetuning.

    Raises:
        ValueError: if ``path`` is neither a directory nor a ``.pdparams``
            prefix.
    """
    has_weights = os.path.isdir(path) or os.path.exists(path + '.pdparams')
    if not has_weights:
        raise ValueError("Model pretrain path {} does not "
                         "exists.".format(path))

    logger.info("Loading parameters from {}...".format(path))

    skipped = set()
    state = _load_state(path)

    # Parameters whose shape differs between the program and the stored
    # weights cannot be assigned, so they are excluded.
    param_shapes = {}
    for blk in prog.blocks:
        for param in blk.all_parameters():
            param_shapes[param.name] = param.shape
    mismatched = [
        var_name for var_name, var_shape in param_shapes.items()
        if var_name in state and var_shape != state[var_name].shape
    ]
    skipped.update(mismatched)

    # Variables explicitly excluded by the caller through regex patterns.
    if ignore_params:
        candidate_names = [var.name for var in prog.list_vars()]
        excluded = [
            var_name for var_name in candidate_names
            if any([re.match(pattern, var_name) for pattern in ignore_params])
        ]
        skipped.update(excluded)

    for var_name in skipped:
        if var_name in state:
            logger.warning(
                'variable {} is already excluded automatically'.format(
                    var_name))
            del state[var_name]

    paddle.static.set_program_state(prog, state)
def init_model(config, program, exe):
    """
    Initialise ``program`` from a checkpoint or from pretrained weights.

    ``config['checkpoints']`` takes precedence: when set, the full checkpoint
    is loaded and nothing else happens. Otherwise every entry of
    ``config['pretrained_model']`` (a single path or a list of paths) is
    loaded in order. With neither key set the function does nothing.
    """
    ckpt = config.get('checkpoints')
    if ckpt:
        paddle.static.load(program, ckpt, exe)
        logger.info("Finish initing model from {}".format(ckpt))
        return

    pretrained = config.get('pretrained_model')
    if not pretrained:
        return

    # Accept both a single path and a list of paths.
    pretrained_paths = pretrained if isinstance(pretrained,
                                                list) else [pretrained]
    for weights_path in pretrained_paths:
        load_params(exe, program, weights_path)
    logger.info("Finish initing model from {}".format(pretrained_paths))
def save_model(program, model_path, epoch_id, prefix='ppcls'):
    """
    Save ``program`` under ``<model_path>/<epoch_id>/<prefix>``.

    Only the process whose distributed rank is 0 writes anything; every
    other rank returns immediately.
    """
    is_rank_zero = paddle.distributed.get_rank() == 0
    if not is_rank_zero:
        return

    epoch_dir = os.path.join(model_path, str(epoch_id))
    _mkdir_if_not_exist(epoch_dir)
    paddle.static.save(program, os.path.join(epoch_dir, prefix))
    logger.info("Already save model in {}".format(epoch_dir))
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
import paddle
from paddle.distributed import fleet
from visualdl import LogWriter
from ppcls.data import build_dataloader
from ppcls.utils.config import get_config, print_config
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.static.save_load import init_model, save_model
from ppcls.static import program
def parse_args():
    """
    Parse the command line of the training script.

    Returns:
        argparse.Namespace with ``config`` (path of the config file) and
        ``override`` (list of override strings; one per ``-o`` occurrence).
    """
    cli = argparse.ArgumentParser("PaddleClas train script")
    # (flags, keyword arguments) for every supported option.
    option_specs = (
        (('-c', '--config'), {
            'type': str,
            'default': 'configs/ResNet/ResNet50.yaml',
            'help': 'config file path',
        }),
        (('-o', '--override'), {
            'action': 'append',
            'default': [],
            'help': 'config options to be overridden',
        }),
    )
    for flags, spec in option_specs:
        cli.add_argument(*flags, **spec)
    return cli.parse_args()
def main(args):
    """
    Run static-graph training, optionally evaluating after epochs.

    All the config of the training paradigm should be in config["Global"].

    Steps, in order: read the config and set up logging; initialise fleet
    when ``is_distributed`` (default True); export the AMP-related flags when
    the config has an ``AMP`` section; select the device; build the
    dataloaders and the train/eval programs; run the startup program; load
    checkpoints or pretrained weights; then loop over epochs, training,
    evaluating and saving according to the configured intervals.

    Args:
        args: parsed command line providing ``config`` and ``override``.
    """
    config = get_config(args.config, overrides=args.override, show=False)
    global_config = config["Global"]

    mode = "train"

    log_file = os.path.join(global_config['output_dir'],
                            config["Arch"]["name"], f"{mode}.log")
    init_logger(name='root', log_file=log_file)
    print_config(config)

    if global_config.get("is_distributed", True):
        fleet.init(is_collective=True)
    # assign the device
    use_gpu = global_config.get("use_gpu", True)
    # amp related config
    if 'AMP' in config:
        # Each flag is listed once (the original repeated the
        # batchnorm_spatial_persistent key with the same value).
        AMP_RELATED_FLAGS_SETTING = {
            'FLAGS_cudnn_exhaustive_search': "1",
            'FLAGS_conv_workspace_size_limit': "1500",
            'FLAGS_cudnn_batchnorm_spatial_persistent': "1",
            'FLAGS_max_indevice_grad_add': "8",
        }
        for k in AMP_RELATED_FLAGS_SETTING:
            os.environ[k] = AMP_RELATED_FLAGS_SETTING[k]

    use_xpu = global_config.get("use_xpu", False)
    assert (
        use_gpu and use_xpu
    ) is not True, "gpu and xpu can not be true in the same time in static mode!"

    if use_gpu:
        device = paddle.set_device('gpu')
    elif use_xpu:
        device = paddle.set_device('xpu')
    else:
        device = paddle.set_device('cpu')

    # visualDL
    vdl_writer = None
    if global_config["use_visualdl"]:
        vdl_dir = os.path.join(global_config["output_dir"], "vdl")
        vdl_writer = LogWriter(vdl_dir)

    # build dataloader
    eval_dataloader = None
    use_dali = global_config.get('use_dali', False)

    train_dataloader = build_dataloader(
        config["DataLoader"], "Train", device=device, use_dali=use_dali)
    if global_config["eval_during_train"]:
        eval_dataloader = build_dataloader(
            config["DataLoader"], "Eval", device=device, use_dali=use_dali)

    step_each_epoch = len(train_dataloader)

    # startup_prog is used to do some parameter init work,
    # and train prog is used to hold the network
    startup_prog = paddle.static.Program()
    train_prog = paddle.static.Program()

    best_top1_acc = 0.0  # best top1 acc record

    train_fetchs, lr_scheduler, train_feeds, optimizer = program.build(
        config,
        train_prog,
        startup_prog,
        step_each_epoch=step_each_epoch,
        is_train=True,
        is_distributed=global_config.get("is_distributed", True))

    if global_config["eval_during_train"]:
        eval_prog = paddle.static.Program()
        eval_fetchs, _, eval_feeds, _ = program.build(
            config,
            eval_prog,
            startup_prog,
            is_train=False,
            is_distributed=global_config.get("is_distributed", True))
        # clone to prune some content which is irrelevant in eval_prog
        eval_prog = eval_prog.clone(for_test=True)

    # create the "Executor" with the statement of which device
    exe = paddle.static.Executor(device)
    # Parameter initialization
    exe.run(startup_prog)
    # load pretrained models or checkpoints
    init_model(global_config, train_prog, exe)

    if 'AMP' in config and config.AMP.get("use_pure_fp16", False):
        optimizer.amp_init(
            device,
            scope=paddle.static.global_scope(),
            test_program=eval_prog
            if global_config["eval_during_train"] else None)

    # Only the non-distributed path compiles the train program; the
    # distributed path runs train_prog as built.
    if not global_config.get("is_distributed", True):
        compiled_train_prog = program.compile(
            config, train_prog, loss_name=train_fetchs["loss"][0].name)
    else:
        compiled_train_prog = train_prog

    if eval_dataloader is not None:
        compiled_eval_prog = program.compile(config, eval_prog)

    for epoch_id in range(global_config["epochs"]):
        # 1. train with train dataset
        program.run(train_dataloader, exe, compiled_train_prog, train_feeds,
                    train_fetchs, epoch_id, 'train', config, vdl_writer,
                    lr_scheduler)

        # 2. evaluate with eval dataset
        if global_config["eval_during_train"] and epoch_id % global_config[
                "eval_interval"] == 0:
            top1_acc = program.run(eval_dataloader, exe, compiled_eval_prog,
                                   eval_feeds, eval_fetchs, epoch_id, "eval",
                                   config)
            if top1_acc > best_top1_acc:
                best_top1_acc = top1_acc
                message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
                    best_top1_acc, epoch_id)
                logger.info(message)
                # NOTE(review): the best model is written only when this
                # epoch also falls on save_interval - confirm that is intended.
                if epoch_id % global_config["save_interval"] == 0:
                    model_path = os.path.join(global_config["output_dir"],
                                              config["Arch"]["name"])
                    save_model(train_prog, model_path, "best_model")

        # 3. save the persistable model
        if epoch_id % global_config["save_interval"] == 0:
            model_path = os.path.join(global_config["output_dir"],
                                      config["Arch"]["name"])
            save_model(train_prog, model_path, epoch_id)
if __name__ == '__main__':
    # Switch Paddle to static-graph mode before the command line is parsed
    # and main() starts building programs.
    paddle.enable_static()
    main(parse_args())
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册