未验证 提交 73004f78 编写于 作者: W Walter 提交者: GitHub

add fp16 amp training and dali (#993)

* add dygraph amp train
上级 8598b46d
...@@ -53,10 +53,13 @@ def create_operators(params): ...@@ -53,10 +53,13 @@ def create_operators(params):
return ops return ops
def build_dataloader(config, mode, device, seed=None): def build_dataloader(config, mode, device, use_dali=False, seed=None):
assert mode in ['Train', 'Eval', 'Test', 'Gallery', 'Query' assert mode in ['Train', 'Eval', 'Test', 'Gallery', 'Query'
], "Mode should be Train, Eval, Test, Gallery, Query" ], "Mode should be Train, Eval, Test, Gallery, Query"
# build dataset # build dataset
if use_dali:
from ppcls.data.dataloader.dali import dali_dataloader
return dali_dataloader(config, mode, paddle.device.get_device(), seed)
config_dataset = config[mode]['dataset'] config_dataset = config[mode]['dataset']
config_dataset = copy.deepcopy(config_dataset) config_dataset = copy.deepcopy(config_dataset)
dataset_name = config_dataset.pop('name') dataset_name = config_dataset.pop('name')
...@@ -71,6 +74,10 @@ def build_dataloader(config, mode, device, seed=None): ...@@ -71,6 +74,10 @@ def build_dataloader(config, mode, device, seed=None):
# build sampler # build sampler
config_sampler = config[mode]['sampler'] config_sampler = config[mode]['sampler']
#config_sampler["batch_size"] = config_sampler[
# "batch_size"] // paddle.distributed.get_world_size()
#assert config_sampler[
# "batch_size"] >= 1, "The batch_size should be larger than gpu number."
if "name" not in config_sampler: if "name" not in config_sampler:
batch_sampler = None batch_sampler = None
batch_size = config_sampler["batch_size"] batch_size = config_sampler["batch_size"]
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import copy
import os
import numpy as np
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import paddle
from nvidia.dali import fn
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.base_iterator import LastBatchPolicy
from nvidia.dali.plugin.paddle import DALIGenericIterator
class HybridTrainPipe(Pipeline):
def __init__(self,
file_root,
file_list,
batch_size,
resize_shorter,
crop,
min_area,
lower,
upper,
interp,
mean,
std,
device_id,
shard_id=0,
num_shards=1,
random_shuffle=True,
num_threads=4,
seed=42,
pad_output=False,
output_dtype=types.FLOAT,
dataset='Train'):
super(HybridTrainPipe, self).__init__(
batch_size, num_threads, device_id, seed=seed)
self.input = ops.readers.File(
file_root=file_root,
file_list=file_list,
shard_id=shard_id,
num_shards=num_shards,
random_shuffle=random_shuffle)
# set internal nvJPEG buffers size to handle full-sized ImageNet images
# without additional reallocations
device_memory_padding = 211025920
host_memory_padding = 140544512
self.decode = ops.decoders.ImageRandomCrop(
device='mixed',
output_type=types.DALIImageType.RGB,
device_memory_padding=device_memory_padding,
host_memory_padding=host_memory_padding,
random_aspect_ratio=[lower, upper],
random_area=[min_area, 1.0],
num_attempts=100)
self.res = ops.Resize(
device='gpu', resize_x=crop, resize_y=crop, interp_type=interp)
self.cmnp = ops.CropMirrorNormalize(
device="gpu",
dtype=output_dtype,
output_layout='CHW',
crop=(crop, crop),
mean=mean,
std=std,
pad_output=pad_output)
self.coin = ops.random.CoinFlip(probability=0.5)
self.to_int64 = ops.Cast(dtype=types.DALIDataType.INT64, device="gpu")
def define_graph(self):
rng = self.coin()
jpegs, labels = self.input(name="Reader")
images = self.decode(jpegs)
images = self.res(images)
output = self.cmnp(images.gpu(), mirror=rng)
return [output, self.to_int64(labels.gpu())]
def __len__(self):
return self.epoch_size("Reader")
class HybridValPipe(Pipeline):
def __init__(self,
file_root,
file_list,
batch_size,
resize_shorter,
crop,
interp,
mean,
std,
device_id,
shard_id=0,
num_shards=1,
random_shuffle=False,
num_threads=4,
seed=42,
pad_output=False,
output_dtype=types.FLOAT):
super(HybridValPipe, self).__init__(
batch_size, num_threads, device_id, seed=seed)
self.input = ops.readers.File(
file_root=file_root,
file_list=file_list,
shard_id=shard_id,
num_shards=num_shards,
random_shuffle=random_shuffle)
self.decode = ops.decoders.Image(device="mixed")
self.res = ops.Resize(
device="gpu", resize_shorter=resize_shorter, interp_type=interp)
self.cmnp = ops.CropMirrorNormalize(
device="gpu",
dtype=output_dtype,
output_layout='CHW',
crop=(crop, crop),
mean=mean,
std=std,
pad_output=pad_output)
self.to_int64 = ops.Cast(dtype=types.DALIDataType.INT64, device="gpu")
def define_graph(self):
jpegs, labels = self.input(name="Reader")
images = self.decode(jpegs)
images = self.res(images)
output = self.cmnp(images)
return [output, self.to_int64(labels.gpu())]
def __len__(self):
return self.epoch_size("Reader")
def dali_dataloader(config, mode, device, seed=None):
assert "gpu" in device, "gpu training is required for DALI"
device_id = int(device.split(':')[1])
config_dataloader = config[mode]
# mode = 'train' if mode.lower() == 'train' else 'eval'
seed = 42 if seed is None else seed
ops = [
list(x.keys())[0]
for x in config_dataloader["dataset"]["transform_ops"]
]
support_ops_train = [
"DecodeImage", "NormalizeImage", "RandFlipImage", "RandCropImage"
]
support_ops_eval = [
"DecodeImage", "ResizeImage", "CropImage", "NormalizeImage"
]
if mode.lower() == 'train':
assert set(ops) == set(
support_ops_train
), "The supported trasform_ops for train_dataset in dali is : {}".format(
",".join(support_ops_train))
else:
assert set(ops) == set(
support_ops_eval
), "The supported trasform_ops for eval_dataset in dali is : {}".format(
",".join(support_ops_eval))
env = os.environ
# assert float(env.get('FLAGS_fraction_of_gpu_memory_to_use', 0.92)) < 0.9, \
# "Please leave enough GPU memory for DALI workspace, e.g., by setting" \
# " `export FLAGS_fraction_of_gpu_memory_to_use=0.8`"
gpu_num = paddle.distributed.get_world_size()
batch_size = config_dataloader["sampler"]["batch_size"]
# assert batch_size % gpu_num == 0, \
# "batch size must be multiple of number of devices"
# batch_size = batch_size // gpu_num
file_root = config_dataloader["dataset"]["image_root"]
file_list = config_dataloader["dataset"]["cls_label_path"]
interp = 1 # settings.interpolation or 1 # default to linear
interp_map = {
0: types.DALIInterpType.INTERP_NN, # cv2.INTER_NEAREST
1: types.DALIInterpType.INTERP_LINEAR, # cv2.INTER_LINEAR
2: types.DALIInterpType.INTERP_CUBIC, # cv2.INTER_CUBIC
3: types.DALIInterpType.
INTERP_LANCZOS3, # XXX use LANCZOS3 for cv2.INTER_LANCZOS4
}
output_dtype = (types.FLOAT16 if 'AMP' in config and
config.AMP.get("use_pure_fp16", False) else types.FLOAT)
assert interp in interp_map, "interpolation method not supported by DALI"
interp = interp_map[interp]
pad_output = False
image_shape = config.get("image_shape", None)
if image_shape and image_shape[0] == 4:
pad_output = True
transforms = {
k: v
for d in config_dataloader["dataset"]["transform_ops"]
for k, v in d.items()
}
scale = transforms["NormalizeImage"].get("scale", 1.0 / 255)
scale = eval(scale) if isinstance(scale, str) else scale
mean = transforms["NormalizeImage"].get("mean", [0.485, 0.456, 0.406])
std = transforms["NormalizeImage"].get("std", [0.229, 0.224, 0.225])
mean = [v / scale for v in mean]
std = [v / scale for v in std]
if mode.lower() == "train":
resize_shorter = 256
crop = transforms["RandCropImage"]["size"]
scale = transforms["RandCropImage"].get("scale", [0.08, 1.])
ratio = transforms["RandCropImage"].get("ratio", [3.0 / 4, 4.0 / 3])
min_area = scale[0]
lower = ratio[0]
upper = ratio[1]
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env:
shard_id = int(env['PADDLE_TRAINER_ID'])
num_shards = int(env['PADDLE_TRAINERS_NUM'])
device_id = int(env['FLAGS_selected_gpus'])
pipe = HybridTrainPipe(
file_root,
file_list,
batch_size,
resize_shorter,
crop,
min_area,
lower,
upper,
interp,
mean,
std,
device_id,
shard_id,
num_shards,
seed=seed + shard_id,
pad_output=pad_output,
output_dtype=output_dtype)
pipe.build()
pipelines = [pipe]
# sample_per_shard = len(pipe) // num_shards
else:
pipe = HybridTrainPipe(
file_root,
file_list,
batch_size,
resize_shorter,
crop,
min_area,
lower,
upper,
interp,
mean,
std,
device_id=device_id,
shard_id=0,
num_shards=1,
seed=seed,
pad_output=pad_output,
output_dtype=output_dtype)
pipe.build()
pipelines = [pipe]
# sample_per_shard = len(pipelines[0])
return DALIGenericIterator(
pipelines, ['data', 'label'], reader_name='Reader')
else:
resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
crop = transforms["CropImage"]["size"]
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env:
shard_id = int(env['PADDLE_TRAINER_ID'])
num_shards = int(env['PADDLE_TRAINERS_NUM'])
device_id = int(env['FLAGS_selected_gpus'])
pipe = HybridValPipe(
file_root,
file_list,
batch_size,
resize_shorter,
crop,
interp,
mean,
std,
device_id=device_id,
shard_id=shard_id,
num_shards=num_shards,
pad_output=pad_output,
output_dtype=output_dtype)
else:
pipe = HybridValPipe(
file_root,
file_list,
batch_size,
resize_shorter,
crop,
interp,
mean,
std,
device_id=device_id,
pad_output=pad_output,
output_dtype=output_dtype)
pipe.build()
return DALIGenericIterator(
[pipe], ['data', 'label'], reader_name="Reader")
...@@ -17,6 +17,7 @@ from __future__ import print_function ...@@ -17,6 +17,7 @@ from __future__ import print_function
import os import os
import sys import sys
import numpy as np import numpy as np
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
...@@ -103,10 +104,19 @@ class Trainer(object): ...@@ -103,10 +104,19 @@ class Trainer(object):
self.query_dataloader = None self.query_dataloader = None
self.eval_mode = self.config["Global"].get("eval_mode", self.eval_mode = self.config["Global"].get("eval_mode",
"classification") "classification")
self.amp = True if "AMP" in self.config else False
if self.amp and self.config["AMP"] is not None:
self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
self.use_dynamic_loss_scaling = self.config["AMP"].get(
"use_dynamic_loss_scaling", False)
else:
self.scale_loss = 1.0
self.use_dynamic_loss_scaling = False
self.train_loss_func = None self.train_loss_func = None
self.eval_loss_func = None self.eval_loss_func = None
self.train_metric_func = None self.train_metric_func = None
self.eval_metric_func = None self.eval_metric_func = None
self.use_dali = self.config['Global'].get("use_dali", False)
def train(self): def train(self):
# build train loss and metric info # build train loss and metric info
...@@ -121,8 +131,8 @@ class Trainer(object): ...@@ -121,8 +131,8 @@ class Trainer(object):
self.train_metric_func = build_metrics(metric_config) self.train_metric_func = build_metrics(metric_config)
if self.train_dataloader is None: if self.train_dataloader is None:
self.train_dataloader = build_dataloader(self.config["DataLoader"], self.train_dataloader = build_dataloader(
"Train", self.device) self.config["DataLoader"], "Train", self.device, self.use_dali)
step_each_epoch = len(self.train_dataloader) step_each_epoch = len(self.train_dataloader)
...@@ -138,7 +148,7 @@ class Trainer(object): ...@@ -138,7 +148,7 @@ class Trainer(object):
"metric": 0.0, "metric": 0.0,
"epoch": 0, "epoch": 0,
} }
# key: # key:
# val: metrics list word # val: metrics list word
output_info = dict() output_info = dict()
time_info = { time_info = {
...@@ -156,28 +166,46 @@ class Trainer(object): ...@@ -156,28 +166,46 @@ class Trainer(object):
if metric_info is not None: if metric_info is not None:
best_metric.update(metric_info) best_metric.update(metric_info)
# for amp training
if self.amp:
scaler = paddle.amp.GradScaler(
init_loss_scaling=self.scale_loss,
use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
tic = time.time() tic = time.time()
max_iter = len(self.train_dataloader) - 1 if platform.system( max_iter = len(self.train_dataloader) - 1 if platform.system(
) == "Windows" else len(self.train_dataloader) ) == "Windows" else len(self.train_dataloader)
for epoch_id in range(best_metric["epoch"] + 1, for epoch_id in range(best_metric["epoch"] + 1,
self.config["Global"]["epochs"] + 1): self.config["Global"]["epochs"] + 1):
acc = 0.0 acc = 0.0
for iter_id, batch in enumerate(self.train_dataloader()): train_dataloader = self.train_dataloader if self.use_dali else self.train_dataloader(
)
for iter_id, batch in enumerate(train_dataloader):
if iter_id >= max_iter: if iter_id >= max_iter:
break break
if iter_id == 5: if iter_id == 5:
for key in time_info: for key in time_info:
time_info[key].reset() time_info[key].reset()
time_info["reader_cost"].update(time.time() - tic) time_info["reader_cost"].update(time.time() - tic)
if self.use_dali:
batch = [
paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label'])
]
batch_size = batch[0].shape[0] batch_size = batch[0].shape[0]
batch[1] = batch[1].reshape([-1, 1]).astype("int64") batch[1] = batch[1].reshape([-1, 1]).astype("int64")
global_step += 1 global_step += 1
# image input # image input
if not self.is_rec: if self.amp:
out = self.model(batch[0]) with paddle.amp.auto_cast(custom_black_list={
"flatten_contiguous_range", "greater_than"
}):
out = self.forward(batch)
loss_dict = self.train_loss_func(out, batch[1])
else: else:
out = self.model(batch[0], batch[1]) out = self.forward(batch)
# calc loss # calc loss
if self.config["DataLoader"]["Train"]["dataset"].get( if self.config["DataLoader"]["Train"]["dataset"].get(
"batch_transform_ops", None): "batch_transform_ops", None):
...@@ -200,8 +228,13 @@ class Trainer(object): ...@@ -200,8 +228,13 @@ class Trainer(object):
batch_size) batch_size)
# step opt and lr # step opt and lr
loss_dict["loss"].backward() if self.amp:
optimizer.step() scaled = scaler.scale(loss_dict["loss"])
scaled.backward()
scaler.minimize(optimizer, scaled)
else:
loss_dict["loss"].backward()
optimizer.step()
optimizer.clear_grad() optimizer.clear_grad()
lr_sch.step() lr_sch.step()
...@@ -244,7 +277,8 @@ class Trainer(object): ...@@ -244,7 +277,8 @@ class Trainer(object):
step=global_step, step=global_step,
writer=self.vdl_writer) writer=self.vdl_writer)
tic = time.time() tic = time.time()
if self.use_dali:
self.train_dataloader.reset()
metric_msg = ", ".join([ metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg) "{}: {:.5f}".format(key, output_info[key].avg)
for key in output_info for key in output_info
...@@ -314,7 +348,8 @@ class Trainer(object): ...@@ -314,7 +348,8 @@ class Trainer(object):
if self.eval_mode == "classification": if self.eval_mode == "classification":
if self.eval_dataloader is None: if self.eval_dataloader is None:
self.eval_dataloader = build_dataloader( self.eval_dataloader = build_dataloader(
self.config["DataLoader"], "Eval", self.device) self.config["DataLoader"], "Eval", self.device,
self.use_dali)
if self.eval_metric_func is None: if self.eval_metric_func is None:
metric_config = self.config.get("Metric") metric_config = self.config.get("Metric")
...@@ -328,11 +363,13 @@ class Trainer(object): ...@@ -328,11 +363,13 @@ class Trainer(object):
elif self.eval_mode == "retrieval": elif self.eval_mode == "retrieval":
if self.gallery_dataloader is None: if self.gallery_dataloader is None:
self.gallery_dataloader = build_dataloader( self.gallery_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Gallery", self.device) self.config["DataLoader"]["Eval"], "Gallery", self.device,
self.use_dali)
if self.query_dataloader is None: if self.query_dataloader is None:
self.query_dataloader = build_dataloader( self.query_dataloader = build_dataloader(
self.config["DataLoader"]["Eval"], "Query", self.device) self.config["DataLoader"]["Eval"], "Query", self.device,
self.use_dali)
# build metric info # build metric info
if self.eval_metric_func is None: if self.eval_metric_func is None:
metric_config = self.config.get("Metric", None) metric_config = self.config.get("Metric", None)
...@@ -348,6 +385,13 @@ class Trainer(object): ...@@ -348,6 +385,13 @@ class Trainer(object):
self.model.train() self.model.train()
return eval_result return eval_result
def forward(self, batch):
if not self.is_rec:
out = self.model(batch[0])
else:
out = self.model(batch[0], batch[1])
return out
@paddle.no_grad() @paddle.no_grad()
def eval_cls(self, epoch_id=0): def eval_cls(self, epoch_id=0):
output_info = dict() output_info = dict()
...@@ -361,24 +405,27 @@ class Trainer(object): ...@@ -361,24 +405,27 @@ class Trainer(object):
metric_key = None metric_key = None
tic = time.time() tic = time.time()
eval_dataloader = self.eval_dataloader if self.use_dali else self.eval_dataloader(
)
max_iter = len(self.eval_dataloader) - 1 if platform.system( max_iter = len(self.eval_dataloader) - 1 if platform.system(
) == "Windows" else len(self.eval_dataloader) ) == "Windows" else len(self.eval_dataloader)
for iter_id, batch in enumerate(self.eval_dataloader()): for iter_id, batch in enumerate(eval_dataloader):
if iter_id >= max_iter: if iter_id >= max_iter:
break break
if iter_id == 5: if iter_id == 5:
for key in time_info: for key in time_info:
time_info[key].reset() time_info[key].reset()
if self.use_dali:
batch = [
paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label'])
]
time_info["reader_cost"].update(time.time() - tic) time_info["reader_cost"].update(time.time() - tic)
batch_size = batch[0].shape[0] batch_size = batch[0].shape[0]
batch[0] = paddle.to_tensor(batch[0]).astype("float32") batch[0] = paddle.to_tensor(batch[0]).astype("float32")
batch[1] = batch[1].reshape([-1, 1]).astype("int64") batch[1] = batch[1].reshape([-1, 1]).astype("int64")
# image input # image input
if self.is_rec: out = self.forward(batch)
out = self.model(batch[0], batch[1])
else:
out = self.model(batch[0])
# calc loss # calc loss
if self.eval_loss_func is not None: if self.eval_loss_func is not None:
loss_dict = self.eval_loss_func(out, batch[-1]) loss_dict = self.eval_loss_func(out, batch[-1])
...@@ -426,7 +473,8 @@ class Trainer(object): ...@@ -426,7 +473,8 @@ class Trainer(object):
len(self.eval_dataloader), metric_msg, time_msg, ips_msg)) len(self.eval_dataloader), metric_msg, time_msg, ips_msg))
tic = time.time() tic = time.time()
if self.use_dali:
self.eval_dataloader.reset()
metric_msg = ", ".join([ metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg) "{}: {:.5f}".format(key, output_info[key].avg)
for key in output_info for key in output_info
...@@ -441,7 +489,6 @@ class Trainer(object): ...@@ -441,7 +489,6 @@ class Trainer(object):
def eval_retrieval(self, epoch_id=0): def eval_retrieval(self, epoch_id=0):
self.model.eval() self.model.eval()
cum_similarity_matrix = None
# step1. build gallery # step1. build gallery
gallery_feas, gallery_img_id, gallery_unique_id = self._cal_feature( gallery_feas, gallery_img_id, gallery_unique_id = self._cal_feature(
name='gallery') name='gallery')
...@@ -516,14 +563,20 @@ class Trainer(object): ...@@ -516,14 +563,20 @@ class Trainer(object):
has_unique_id = False has_unique_id = False
max_iter = len(dataloader) - 1 if platform.system( max_iter = len(dataloader) - 1 if platform.system(
) == "Windows" else len(dataloader) ) == "Windows" else len(dataloader)
for idx, batch in enumerate(dataloader( dataloader_tmp = dataloader if self.use_dali else dataloader()
)): # load is very time-consuming for idx, batch in enumerate(
dataloader_tmp): # load is very time-consuming
if idx >= max_iter: if idx >= max_iter:
break break
if idx % self.config["Global"]["print_batch_step"] == 0: if idx % self.config["Global"]["print_batch_step"] == 0:
logger.info( logger.info(
f"{name} feature calculation process: [{idx}/{len(dataloader)}]" f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
) )
if self.use_dali:
batch = [
paddle.to_tensor(batch[0]['data']),
paddle.to_tensor(batch[0]['label'])
]
batch = [paddle.to_tensor(x) for x in batch] batch = [paddle.to_tensor(x) for x in batch]
batch[1] = batch[1].reshape([-1, 1]).astype("int64") batch[1] = batch[1].reshape([-1, 1]).astype("int64")
if len(batch) == 3: if len(batch) == 3:
...@@ -549,7 +602,8 @@ class Trainer(object): ...@@ -549,7 +602,8 @@ class Trainer(object):
all_image_id = paddle.concat([all_image_id, batch[1]]) all_image_id = paddle.concat([all_image_id, batch[1]])
if has_unique_id: if has_unique_id:
all_unique_id = paddle.concat([all_unique_id, batch[2]]) all_unique_id = paddle.concat([all_unique_id, batch[2]])
if self.use_dali:
dataloader_tmp.reset()
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
feat_list = [] feat_list = []
img_id_list = [] img_id_list = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册