“8d4af9c50cad4d233ad12e13ad08881d22fc782c”上不存在“data_backup/1.leetcode/127_最长连续序列/desc.html”
未验证 提交 d7c55d87 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

rm unused static code (#1095)

上级 00455839
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import os
import numpy as np
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.plugin.paddle import DALIGenericIterator
import paddle
from paddle import fluid
class HybridTrainPipe(Pipeline):
def __init__(self,
file_root,
file_list,
batch_size,
resize_shorter,
crop,
min_area,
lower,
upper,
interp,
mean,
std,
device_id,
shard_id=0,
num_shards=1,
random_shuffle=True,
num_threads=4,
seed=42,
pad_output=False,
output_dtype=types.FLOAT):
super(HybridTrainPipe, self).__init__(
batch_size, num_threads, device_id, seed=seed)
self.input = ops.FileReader(
file_root=file_root,
file_list=file_list,
shard_id=shard_id,
num_shards=num_shards,
random_shuffle=random_shuffle)
# set internal nvJPEG buffers size to handle full-sized ImageNet images
# without additional reallocations
device_memory_padding = 211025920
host_memory_padding = 140544512
self.decode = ops.ImageDecoderRandomCrop(
device='mixed',
output_type=types.RGB,
device_memory_padding=device_memory_padding,
host_memory_padding=host_memory_padding,
random_aspect_ratio=[lower, upper],
random_area=[min_area, 1.0],
num_attempts=100)
self.res = ops.Resize(
device='gpu', resize_x=crop, resize_y=crop, interp_type=interp)
self.cmnp = ops.CropMirrorNormalize(
device="gpu",
output_dtype=output_dtype,
output_layout=types.NCHW,
crop=(crop, crop),
image_type=types.RGB,
mean=mean,
std=std,
pad_output=pad_output)
self.coin = ops.CoinFlip(probability=0.5)
self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu")
def define_graph(self):
rng = self.coin()
jpegs, labels = self.input(name="Reader")
images = self.decode(jpegs)
images = self.res(images)
output = self.cmnp(images.gpu(), mirror=rng)
return [output, self.to_int64(labels.gpu())]
def __len__(self):
return self.epoch_size("Reader")
class HybridValPipe(Pipeline):
def __init__(self,
file_root,
file_list,
batch_size,
resize_shorter,
crop,
interp,
mean,
std,
device_id,
shard_id=0,
num_shards=1,
random_shuffle=False,
num_threads=4,
seed=42,
pad_output=False,
output_dtype=types.FLOAT):
super(HybridValPipe, self).__init__(
batch_size, num_threads, device_id, seed=seed)
self.input = ops.FileReader(
file_root=file_root,
file_list=file_list,
shard_id=shard_id,
num_shards=num_shards,
random_shuffle=random_shuffle)
self.decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
self.res = ops.Resize(
device="gpu", resize_shorter=resize_shorter, interp_type=interp)
self.cmnp = ops.CropMirrorNormalize(
device="gpu",
output_dtype=output_dtype,
output_layout=types.NCHW,
crop=(crop, crop),
image_type=types.RGB,
mean=mean,
std=std,
pad_output=pad_output)
self.to_int64 = ops.Cast(dtype=types.INT64, device="gpu")
def define_graph(self):
jpegs, labels = self.input(name="Reader")
images = self.decode(jpegs)
images = self.res(images)
output = self.cmnp(images)
return [output, self.to_int64(labels.gpu())]
def __len__(self):
return self.epoch_size("Reader")
def build(config, mode='train'):
env = os.environ
assert config.get('use_gpu',
True) == True, "gpu training is required for DALI"
assert not config.get(
'use_aa'), "auto augment is not supported by DALI reader"
assert float(env.get('FLAGS_fraction_of_gpu_memory_to_use', 0.92)) < 0.9, \
"Please leave enough GPU memory for DALI workspace, e.g., by setting" \
" `export FLAGS_fraction_of_gpu_memory_to_use=0.8`"
dataset_config = config[mode.upper()]
gpu_num = paddle.fluid.core.get_cuda_device_count() if (
'PADDLE_TRAINERS_NUM') and (
'PADDLE_TRAINER_ID'
) not in env else int(env.get('PADDLE_TRAINERS_NUM', 0))
batch_size = dataset_config.batch_size
assert batch_size % gpu_num == 0, \
"batch size must be multiple of number of devices"
batch_size = batch_size // gpu_num
file_root = dataset_config.data_dir
file_list = dataset_config.file_list
interp = 1 # settings.interpolation or 1 # default to linear
interp_map = {
0: types.INTERP_NN, # cv2.INTER_NEAREST
1: types.INTERP_LINEAR, # cv2.INTER_LINEAR
2: types.INTERP_CUBIC, # cv2.INTER_CUBIC
4: types.INTERP_LANCZOS3, # XXX use LANCZOS3 for cv2.INTER_LANCZOS4
}
output_dtype = (types.FLOAT16 if 'AMP' in config and
config.AMP.get("use_pure_fp16", False)
else types.FLOAT)
assert interp in interp_map, "interpolation method not supported by DALI"
interp = interp_map[interp]
pad_output = False
image_shape = config.get("image_shape", None)
if image_shape and image_shape[0] == 4:
pad_output = True
transforms = {
k: v
for d in dataset_config["transforms"] for k, v in d.items()
}
scale = transforms["NormalizeImage"].get("scale", 1.0 / 255)
if isinstance(scale, str):
scale = eval(scale)
mean = transforms["NormalizeImage"].get("mean", [0.485, 0.456, 0.406])
std = transforms["NormalizeImage"].get("std", [0.229, 0.224, 0.225])
mean = [v / scale for v in mean]
std = [v / scale for v in std]
if mode == "train":
resize_shorter = 256
crop = transforms["RandCropImage"]["size"]
scale = transforms["RandCropImage"].get("scale", [0.08, 1.])
ratio = transforms["RandCropImage"].get("ratio", [3.0 / 4, 4.0 / 3])
min_area = scale[0]
lower = ratio[0]
upper = ratio[1]
if 'PADDLE_TRAINER_ID' in env and 'PADDLE_TRAINERS_NUM' in env:
shard_id = int(env['PADDLE_TRAINER_ID'])
num_shards = int(env['PADDLE_TRAINERS_NUM'])
device_id = int(env['FLAGS_selected_gpus'])
pipe = HybridTrainPipe(
file_root,
file_list,
batch_size,
resize_shorter,
crop,
min_area,
lower,
upper,
interp,
mean,
std,
device_id,
shard_id,
num_shards,
seed=42 + shard_id,
pad_output=pad_output,
output_dtype=output_dtype)
pipe.build()
pipelines = [pipe]
sample_per_shard = len(pipe) // num_shards
else:
pipelines = []
places = fluid.framework.cuda_places()
num_shards = len(places)
for idx, p in enumerate(places):
place = fluid.core.Place()
place.set_place(p)
device_id = place.gpu_device_id()
pipe = HybridTrainPipe(
file_root,
file_list,
batch_size,
resize_shorter,
crop,
min_area,
lower,
upper,
interp,
mean,
std,
device_id,
idx,
num_shards,
seed=42 + idx,
pad_output=pad_output,
output_dtype=output_dtype)
pipe.build()
pipelines.append(pipe)
sample_per_shard = len(pipelines[0])
return DALIGenericIterator(
pipelines, ['feed_image', 'feed_label'], size=sample_per_shard)
else:
resize_shorter = transforms["ResizeImage"].get("resize_short", 256)
crop = transforms["CropImage"]["size"]
p = fluid.framework.cuda_places()[0]
place = fluid.core.Place()
place.set_place(p)
device_id = place.gpu_device_id()
pipe = HybridValPipe(
file_root,
file_list,
batch_size,
resize_shorter,
crop,
interp,
mean,
std,
device_id=device_id,
pad_output=pad_output,
output_dtype=output_dtype)
pipe.build()
return DALIGenericIterator(
pipe, ['feed_image', 'feed_label'],
size=len(pipe),
dynamic_shape=True,
fill_last_batch=True,
last_batch_padded=True)
def train(config):
return build(config, 'train')
def val(config):
return build(config, 'valid')
def _to_Tensor(lod_tensor, dtype):
data_tensor = fluid.layers.create_tensor(dtype=dtype)
data = np.array(lod_tensor).astype(dtype)
fluid.layers.assign(data, data_tensor)
return data_tensor
def normalize(feeds, config):
image, label = feeds['image'], feeds['label']
img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
image = fluid.layers.cast(image, 'float32')
costant = fluid.layers.fill_constant(
shape=[1], value=255.0, dtype='float32')
image = fluid.layers.elementwise_div(image, costant)
mean = fluid.layers.create_tensor(dtype="float32")
fluid.layers.assign(input=img_mean.astype("float32"), output=mean)
std = fluid.layers.create_tensor(dtype="float32")
fluid.layers.assign(input=img_std.astype("float32"), output=std)
image = fluid.layers.elementwise_sub(image, mean)
image = fluid.layers.elementwise_div(image, std)
image.stop_gradient = True
feeds['image'] = image
return feeds
def mix(feeds, config, is_train=True):
env = os.environ
gpu_num = paddle.fluid.core.get_cuda_device_count() if (
'PADDLE_TRAINERS_NUM') and (
'PADDLE_TRAINER_ID'
) not in env else int(env.get('PADDLE_TRAINERS_NUM', 0))
batch_size = config.TRAIN.batch_size // gpu_num
images = feeds['image']
label = feeds['label']
# TODO: hard code here, should be fixed!
alpha = 0.2
idx = _to_Tensor(np.random.permutation(batch_size), 'int32')
lam = np.random.beta(alpha, alpha)
images = lam * images + (1 - lam) * paddle.fluid.layers.gather(images, idx)
feed = {
'image': images,
'feed_y_a': label,
'feed_y_b': paddle.fluid.layers.gather(label, idx),
'feed_lam': _to_Tensor([lam] * batch_size, 'float32')
}
return feed if is_train else feeds
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
from collections import OrderedDict
from ppcls.optimizer import OptimizerBuilder
import paddle
import paddle.nn.functional as F
from ppcls.optimizer.learning_rate import LearningRateBuilder
from ppcls.arch import backbone
from ppcls.arch.loss import CELoss
from ppcls.arch.loss import MixCELoss
from ppcls.arch.loss import JSDivLoss
from ppcls.arch.loss import GoogLeNetLoss
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger, profiler
from paddle.distributed import fleet
from paddle.distributed.fleet import DistributedStrategy
def create_feeds(image_shape, use_mix=None, use_dali=None, dtype="float32"):
"""
Create feeds as model input
Args:
image_shape(list[int]): model input shape, such as [3, 224, 224]
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
Returns:
feeds(dict): dict of model input variables
"""
feeds = OrderedDict()
feeds['image'] = paddle.static.data(
name="feed_image", shape=[None] + image_shape, dtype=dtype)
if use_mix and not use_dali:
feeds['feed_y_a'] = paddle.static.data(
name="feed_y_a", shape=[None, 1], dtype="int64")
feeds['feed_y_b'] = paddle.static.data(
name="feed_y_b", shape=[None, 1], dtype="int64")
feeds['feed_lam'] = paddle.static.data(
name="feed_lam", shape=[None, 1], dtype=dtype)
else:
feeds['label'] = paddle.static.data(
name="feed_label", shape=[None, 1], dtype="int64")
return feeds
def create_model(architecture, image, classes_num, config, is_train):
"""
Create a model
Args:
architecture(dict): architecture information,
name(such as ResNet50) is needed
image(variable): model input variable
classes_num(int): num of classes
config(dict): model config
Returns:
out(variable): model output variable
"""
name = architecture["name"]
params = architecture.get("params", {})
if "data_format" in config:
params["data_format"] = config["data_format"]
data_format = config["data_format"]
input_image_channel = config.get('image_shape', [3, 224, 224])[0]
if input_image_channel != 3:
logger.warning(
"Input image channel is changed to {}, maybe for better speed-up".
format(input_image_channel))
params["input_image_channel"] = input_image_channel
if "is_test" in params:
params['is_test'] = not is_train
model = backbone.__dict__[name](class_dim=classes_num, **params)
out = model(image)
return out
def create_loss(out,
feeds,
architecture,
classes_num=1000,
epsilon=None,
use_mix=False,
use_distillation=False):
"""
Create a loss for optimization, such as:
1. CrossEnotry loss
2. CrossEnotry loss with label smoothing
3. CrossEnotry loss with mix(mixup, cutmix, fmix)
4. CrossEnotry loss with label smoothing and (mixup, cutmix, fmix)
5. GoogLeNet loss
Args:
out(variable): model output variable
feeds(dict): dict of model input variables
architecture(dict): architecture information,
name(such as ResNet50) is needed
classes_num(int): num of classes
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
Returns:
loss(variable): loss variable
"""
if use_mix:
feed_y_a = paddle.reshape(feeds['feed_y_a'], [-1, 1])
feed_y_b = paddle.reshape(feeds['feed_y_b'], [-1, 1])
feed_lam = paddle.reshape(feeds['feed_lam'], [-1, 1])
else:
target = paddle.reshape(feeds['label'], [-1, 1])
if architecture["name"] == "GoogLeNet":
assert len(out) == 3, "GoogLeNet should have 3 outputs"
loss = GoogLeNetLoss(class_dim=classes_num, epsilon=epsilon)
return loss(out[0], out[1], out[2], target)
if use_distillation:
assert len(out) == 2, ("distillation output length must be 2, "
"but got {}".format(len(out)))
loss = JSDivLoss(class_dim=classes_num, epsilon=epsilon)
return loss(out[1], out[0])
if use_mix:
loss = MixCELoss(class_dim=classes_num, epsilon=epsilon)
return loss(out, feed_y_a, feed_y_b, feed_lam)
else:
loss = CELoss(class_dim=classes_num, epsilon=epsilon)
return loss(out, target)
def create_metric(out,
feeds,
architecture,
topk=5,
classes_num=1000,
config=None,
use_distillation=False):
"""
Create measures of model accuracy, such as top1 and top5
Args:
out(variable): model output variable
feeds(dict): dict of model input variables(included label)
topk(int): usually top5
classes_num(int): num of classes
config(dict) : model config
Returns:
fetchs(dict): dict of measures
"""
label = paddle.reshape(feeds['label'], [-1, 1])
if architecture["name"] == "GoogLeNet":
assert len(out) == 3, "GoogLeNet should have 3 outputs"
out = out[0]
else:
# just need student label to get metrics
if use_distillation:
out = out[1]
softmax_out = F.softmax(out)
fetchs = OrderedDict()
# set top1 to fetchs
top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
fetchs['top1'] = (top1, AverageMeter('top1', '.4f', need_avg=True))
# set topk to fetchs
k = min(topk, classes_num)
topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
topk_name = 'top{}'.format(k)
fetchs[topk_name] = (topk, AverageMeter(topk_name, '.4f', need_avg=True))
return fetchs
def create_fetchs(out,
feeds,
architecture,
topk=5,
classes_num=1000,
epsilon=None,
use_mix=False,
config=None,
use_distillation=False):
"""
Create fetchs as model outputs(included loss and measures),
will call create_loss and create_metric(if use_mix).
Args:
out(variable): model output variable
feeds(dict): dict of model input variables.
If use mix_up, it will not include label.
architecture(dict): architecture information,
name(such as ResNet50) is needed
topk(int): usually top5
classes_num(int): num of classes
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
config(dict): model config
Returns:
fetchs(dict): dict of model outputs(included loss and measures)
"""
fetchs = OrderedDict()
loss = create_loss(out, feeds, architecture, classes_num, epsilon, use_mix,
use_distillation)
fetchs['loss'] = (loss, AverageMeter('loss', '7.4f', need_avg=True))
if not use_mix:
metric = create_metric(out, feeds, architecture, topk, classes_num,
config, use_distillation)
fetchs.update(metric)
return fetchs
def create_optimizer(config):
"""
Create an optimizer using config, usually including
learning rate and regularization.
Args:
config(dict): such as
{
'LEARNING_RATE':
{'function': 'Cosine',
'params': {'lr': 0.1}
},
'OPTIMIZER':
{'function': 'Momentum',
'params':{'momentum': 0.9},
'regularizer':
{'function': 'L2', 'factor': 0.0001}
}
}
Returns:
an optimizer instance
"""
# create learning_rate instance
lr_config = config['LEARNING_RATE']
lr_config['params'].update({
'epochs': config['epochs'],
'step_each_epoch':
config['total_images'] // config['TRAIN']['batch_size'],
})
lr = LearningRateBuilder(**lr_config)()
# create optimizer instance
opt_config = config['OPTIMIZER']
opt = OptimizerBuilder(**opt_config)
return opt(lr), lr
def create_strategy(config):
"""
Create build strategy and exec strategy.
Args:
config(dict): config
Returns:
build_strategy: build strategy
exec_strategy: exec strategy
"""
build_strategy = paddle.static.BuildStrategy()
exec_strategy = paddle.static.ExecutionStrategy()
exec_strategy.num_threads = 1
exec_strategy.num_iteration_per_drop_scope = (
10000
if 'AMP' in config and config.AMP.get("use_pure_fp16", False) else 10)
fuse_op = True if 'AMP' in config else False
fuse_bn_act_ops = config.get('fuse_bn_act_ops', fuse_op)
fuse_elewise_add_act_ops = config.get('fuse_elewise_add_act_ops', fuse_op)
fuse_bn_add_act_ops = config.get('fuse_bn_add_act_ops', fuse_op)
enable_addto = config.get('enable_addto', fuse_op)
try:
build_strategy.fuse_bn_act_ops = fuse_bn_act_ops
except Exception as e:
logger.info(
"PaddlePaddle version 1.7.0 or higher is "
"required when you want to fuse batch_norm and activation_op.")
try:
build_strategy.fuse_elewise_add_act_ops = fuse_elewise_add_act_ops
except Exception as e:
logger.info(
"PaddlePaddle version 1.7.0 or higher is "
"required when you want to fuse elewise_add_act and activation_op.")
try:
build_strategy.fuse_bn_add_act_ops = fuse_bn_add_act_ops
except Exception as e:
logger.info(
"PaddlePaddle 2.0-rc or higher is "
"required when you want to enable fuse_bn_add_act_ops strategy.")
try:
build_strategy.enable_addto = enable_addto
except Exception as e:
logger.info("PaddlePaddle 2.0-rc or higher is "
"required when you want to enable addto strategy.")
return build_strategy, exec_strategy
def dist_optimizer(config, optimizer):
"""
Create a distributed optimizer based on a normal optimizer
Args:
config(dict):
optimizer(): a normal optimizer
Returns:
optimizer: a distributed optimizer
"""
build_strategy, exec_strategy = create_strategy(config)
dist_strategy = DistributedStrategy()
dist_strategy.execution_strategy = exec_strategy
dist_strategy.build_strategy = build_strategy
dist_strategy.nccl_comm_num = 1
dist_strategy.fuse_all_reduce_ops = True
dist_strategy.fuse_grad_size_in_MB = 16
optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
return optimizer
def mixed_precision_optimizer(config, optimizer):
if 'AMP' in config:
amp_cfg = config.AMP if config.AMP else dict()
scale_loss = amp_cfg.get('scale_loss', 1.0)
use_dynamic_loss_scaling = amp_cfg.get('use_dynamic_loss_scaling',
False)
use_pure_fp16 = amp_cfg.get('use_pure_fp16', False)
optimizer = paddle.static.amp.decorate(
optimizer,
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling,
use_pure_fp16=use_pure_fp16,
use_fp16_guard=True)
return optimizer
def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
"""
Build a program using a model and an optimizer
1. create feeds
2. create a dataloader
3. create a model
4. create fetchs
5. create an optimizer
Args:
config(dict): config
main_prog(): main program
startup_prog(): startup program
is_train(bool): train or valid
is_distributed(bool): whether to use distributed training method
Returns:
dataloader(): a bridge between the model and the data
fetchs(dict): dict of model outputs(included loss and measures)
"""
with paddle.static.program_guard(main_prog, startup_prog):
with paddle.utils.unique_name.guard():
use_mix = config.get('use_mix') and is_train
use_dali = config.get('use_dali', False)
use_distillation = config.get('use_distillation')
feeds = create_feeds(
config.image_shape,
use_mix=use_mix,
use_dali=use_dali,
dtype="float32")
if use_dali and use_mix:
import dali
feeds = dali.mix(feeds, config, is_train)
out = create_model(config.ARCHITECTURE, feeds['image'],
config.classes_num, config, is_train)
fetchs = create_fetchs(
out,
feeds,
config.ARCHITECTURE,
config.topk,
config.classes_num,
epsilon=config.get('ls_epsilon'),
use_mix=use_mix,
config=config,
use_distillation=use_distillation)
lr_scheduler = None
optimizer = None
if is_train:
optimizer, lr_scheduler = create_optimizer(config)
optimizer = mixed_precision_optimizer(config, optimizer)
if is_distributed:
optimizer = dist_optimizer(config, optimizer)
optimizer.minimize(fetchs['loss'][0])
return fetchs, lr_scheduler, feeds, optimizer
def compile(config, program, loss_name=None, share_prog=None):
"""
Compile the program
Args:
config(dict): config
program(): the program which is wrapped by
loss_name(str): loss name
share_prog(): the shared program, used for evaluation during training
Returns:
compiled_program(): a compiled program
"""
build_strategy, exec_strategy = create_strategy(config)
compiled_program = paddle.static.CompiledProgram(
program).with_data_parallel(
share_vars_from=share_prog,
loss_name=loss_name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
return compiled_program
total_step = 0
def run(dataloader,
exe,
program,
feeds,
fetchs,
epoch=0,
mode='train',
config=None,
vdl_writer=None,
lr_scheduler=None,
profiler_options=None):
"""
Feed data to the model and fetch the measures and loss
Args:
dataloader(paddle io dataloader):
exe():
program():
fetchs(dict): dict of measures and the loss
epoch(int): epoch of training or validation
model(str): log only
Returns:
"""
fetch_list = [f[0] for f in fetchs.values()]
metric_list = [
("lr", AverageMeter(
'lr', 'f', postfix=",", need_avg=False)),
("batch_time", AverageMeter(
'batch_cost', '.5f', postfix=" s,")),
("reader_time", AverageMeter(
'reader_cost', '.5f', postfix=" s,")),
]
topk_name = 'top{}'.format(config.topk)
metric_list.insert(0, ("loss", fetchs["loss"][1]))
use_mix = config.get("use_mix", False) and mode == "train"
if not use_mix:
metric_list.insert(0, (topk_name, fetchs[topk_name][1]))
metric_list.insert(0, ("top1", fetchs["top1"][1]))
metric_list = OrderedDict(metric_list)
for m in metric_list.values():
m.reset()
use_dali = config.get('use_dali', False)
dataloader = dataloader if use_dali else dataloader()
tic = time.time()
idx = 0
batch_size = None
while True:
# The DALI maybe raise RuntimeError for some particular images, such as ImageNet1k/n04418357_26036.JPEG
try:
batch = next(dataloader)
except StopIteration:
break
except RuntimeError:
logger.warning(
"Except RuntimeError when reading data from dataloader, try to read once again..."
)
continue
idx += 1
# ignore the warmup iters
if idx == 5:
metric_list["batch_time"].reset()
metric_list["reader_time"].reset()
metric_list['reader_time'].update(time.time() - tic)
profiler.add_profiler_step(profiler_options)
if use_dali:
batch_size = batch[0]["feed_image"].shape()[0]
feed_dict = batch[0]
else:
batch_size = batch[0].shape()[0]
feed_dict = {
key.name: batch[idx]
for idx, key in enumerate(feeds.values())
}
metrics = exe.run(program=program,
feed=feed_dict,
fetch_list=fetch_list)
for name, m in zip(fetchs.keys(), metrics):
metric_list[name].update(np.mean(m), batch_size)
metric_list["batch_time"].update(time.time() - tic)
if mode == "train":
metric_list['lr'].update(lr_scheduler.get_lr())
fetchs_str = ' '.join([
str(metric_list[key].mean)
if "time" in key else str(metric_list[key].value)
for key in metric_list
])
ips_info = " ips: {:.5f} images/sec.".format(
batch_size / metric_list["batch_time"].avg)
fetchs_str += ips_info
if lr_scheduler is not None:
if lr_scheduler.update_specified:
curr_global_counter = lr_scheduler.step_each_epoch * epoch + idx
update = max(
0, curr_global_counter - lr_scheduler.
update_start_step) % lr_scheduler.update_step_interval == 0
if update:
lr_scheduler.step()
else:
lr_scheduler.step()
if vdl_writer:
global total_step
logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
total_step += 1
if mode == 'valid':
if idx % config.get('print_interval', 10) == 0:
logger.info("{:s} step:{:<4d} {:s}".format(mode, idx,
fetchs_str))
else:
epoch_str = "epoch:{:<3d}".format(epoch)
step_str = "{:s} step:{:<4d}".format(mode, idx)
if idx % config.get('print_interval', 10) == 0:
logger.info("{:s} {:s} {:s}".format(
logger.coloring(epoch_str, "HEADER")
if idx == 0 else epoch_str,
logger.coloring(step_str, "PURPLE"),
logger.coloring(fetchs_str, 'OKGREEN')))
tic = time.time()
end_str = ' '.join([str(m.mean) for m in metric_list.values()] +
[metric_list["batch_time"].total])
ips_info = "ips: {:.5f} images/sec.".format(
batch_size * metric_list["batch_time"].count /
metric_list["batch_time"].sum)
if mode == 'valid':
logger.info("END {:s} {:s} {:s}".format(mode, end_str, ips_info))
else:
end_epoch_str = "END epoch:{:<3d}".format(epoch)
logger.info("{:s} {:s} {:s} {:s}".format(end_epoch_str, mode, end_str,
ips_info))
if use_dali:
dataloader.reset()
# return top1_acc in order to save the best model
if mode == 'valid':
return fetchs["top1"][1].avg
#!/usr/bin/env bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export FLAGS_fraction_of_gpu_memory_to_use=0.80
python3.7 -m paddle.distributed.launch \
--gpus="0,1,2,3" \
tools/static/train.py \
-c ./configs/ResNet/ResNet50.yaml \
-o print_interval=10 \
-o use_dali=True
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import errno
import os
import re
import shutil
import tempfile
import paddle
from ppcls.utils import logger
__all__ = ['init_model', 'save_model']
def _mkdir_if_not_exist(path):
"""
mkdir if not exists, ignore the exception when multiprocess mkdir together
"""
if not os.path.exists(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno == errno.EEXIST and os.path.isdir(path):
logger.warning(
'be happy if some process has already created {}'.format(
path))
else:
raise OSError('Failed to mkdir {}'.format(path))
def _load_state(path):
if os.path.exists(path + '.pdopt'):
# XXX another hack to ignore the optimizer state
tmp = tempfile.mkdtemp()
dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
shutil.copy(path + '.pdparams', dst + '.pdparams')
state = paddle.static.load_program_state(dst)
shutil.rmtree(tmp)
else:
state = paddle.static.load_program_state(path)
return state
def load_params(exe, prog, path, ignore_params=None):
"""
Load model from the given path.
Args:
exe (fluid.Executor): The fluid.Executor object.
prog (fluid.Program): load weight to which Program object.
path (string): URL string or loca model path.
ignore_params (list): ignore variable to load when finetuning.
It can be specified by finetune_exclude_pretrained_params
and the usage can refer to the document
docs/advanced_tutorials/TRANSFER_LEARNING.md
"""
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
logger.info(
logger.coloring('Loading parameters from {}...'.format(path),
'HEADER'))
ignore_set = set()
state = _load_state(path)
# ignore the parameter which mismatch the shape
# between the model and pretrain weight.
all_var_shape = {}
for block in prog.blocks:
for param in block.all_parameters():
all_var_shape[param.name] = param.shape
ignore_set.update([
name for name, shape in all_var_shape.items()
if name in state and shape != state[name].shape
])
if ignore_params:
all_var_names = [var.name for var in prog.list_vars()]
ignore_list = filter(
lambda var: any([re.match(name, var) for name in ignore_params]),
all_var_names)
ignore_set.update(list(ignore_list))
if len(ignore_set) > 0:
for k in ignore_set:
if k in state:
logger.warning(
'variable {} is already excluded automatically'.format(k))
del state[k]
paddle.static.set_program_state(prog, state)
def init_model(config, program, exe):
"""
load model from checkpoint or pretrained_model
"""
checkpoints = config.get('checkpoints')
if checkpoints:
paddle.static.load(program, checkpoints, exe)
logger.info(
logger.coloring("Finish initing model from {}".format(checkpoints),
"HEADER"))
return
pretrained_model = config.get('pretrained_model')
if pretrained_model:
if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model]
for pretrain in pretrained_model:
load_params(exe, program, pretrain)
logger.info(
logger.coloring("Finish initing model from {}".format(
pretrained_model), "HEADER"))
def save_model(program, model_path, epoch_id, prefix='ppcls'):
"""
save model to the target path
"""
model_path = os.path.join(model_path, str(epoch_id))
_mkdir_if_not_exist(model_path)
model_prefix = os.path.join(model_path, prefix)
paddle.static.save(program, model_prefix)
logger.info(
logger.coloring("Already save model in {}".format(model_path),
"HEADER"))
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
from sys import version_info
import paddle
from paddle.distributed import fleet
from ppcls.data import Reader
from ppcls.utils.config import get_config
from ppcls.utils import logger
from tools.static import program
from save_load import init_model, save_model
def parse_args():
parser = argparse.ArgumentParser("PaddleClas train script")
parser.add_argument(
'-c',
'--config',
type=str,
default='configs/ResNet/ResNet50.yaml',
help='config file path')
parser.add_argument(
'--vdl_dir',
type=str,
default=None,
help='VisualDL logging directory for image.')
parser.add_argument(
'-p',
'--profiler_options',
type=str,
default=None,
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
)
parser.add_argument(
'-o',
'--override',
action='append',
default=[],
help='config options to be overridden')
args = parser.parse_args()
return args
def main(args):
config = get_config(args.config, overrides=args.override, show=True)
if config.get("is_distributed", True):
fleet.init(is_collective=True)
# assign the place
use_gpu = config.get("use_gpu", True)
# amp related config
if 'AMP' in config:
AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_exhaustive_search': 1,
'FLAGS_conv_workspace_size_limit': 1500,
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
'FLAGS_max_inplace_grad_add': 8,
}
os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1'
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
use_xpu = config.get("use_xpu", False)
assert (
use_gpu and use_xpu
) is not True, "gpu and xpu can not be true in the same time in static mode!"
if use_gpu:
place = paddle.set_device('gpu')
elif use_xpu:
place = paddle.set_device('xpu')
else:
place = paddle.set_device('cpu')
# startup_prog is used to do some parameter init work,
# and train prog is used to hold the network
startup_prog = paddle.static.Program()
train_prog = paddle.static.Program()
best_top1_acc = 0.0 # best top1 acc record
train_fetchs, lr_scheduler, train_feeds, optimizer = program.build(
config,
train_prog,
startup_prog,
is_train=True,
is_distributed=config.get("is_distributed", True))
if config.validate:
valid_prog = paddle.static.Program()
valid_fetchs, _, valid_feeds, _ = program.build(
config,
valid_prog,
startup_prog,
is_train=False,
is_distributed=config.get("is_distributed", True))
# clone to prune some content which is irrelevant in valid_prog
valid_prog = valid_prog.clone(for_test=True)
# create the "Executor" with the statement of which place
exe = paddle.static.Executor(place)
# Parameter initialization
exe.run(startup_prog)
# load pretrained models or checkpoints
init_model(config, train_prog, exe)
if 'AMP' in config and config.AMP.get("use_pure_fp16", False):
optimizer.amp_init(
place,
scope=paddle.static.global_scope(),
test_program=valid_prog if config.validate else None)
if not config.get("is_distributed", True):
compiled_train_prog = program.compile(
config, train_prog, loss_name=train_fetchs["loss"][0].name)
else:
compiled_train_prog = train_prog
if not config.get('use_dali', False):
train_dataloader = Reader(config, 'train', places=place)()
if config.validate and paddle.distributed.get_rank() == 0:
valid_dataloader = Reader(config, 'valid', places=place)()
compiled_valid_prog = program.compile(config, valid_prog)
else:
assert use_gpu is True, "DALI only support gpu, please set use_gpu to True!"
import dali
train_dataloader = dali.train(config)
if config.validate and paddle.distributed.get_rank() == 0:
valid_dataloader = dali.val(config)
compiled_valid_prog = program.compile(config, valid_prog)
vdl_writer = None
if args.vdl_dir:
if version_info.major == 2:
logger.info(
"visualdl is just supported for python3, so it is disabled in python2..."
)
else:
from visualdl import LogWriter
vdl_writer = LogWriter(args.vdl_dir)
for epoch_id in range(config.epochs):
# 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_feeds,
train_fetchs, epoch_id, 'train', config, vdl_writer,
lr_scheduler, args.profiler_options)
if paddle.distributed.get_rank() == 0:
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0:
top1_acc = program.run(valid_dataloader, exe,
compiled_valid_prog, valid_feeds,
valid_fetchs, epoch_id, 'valid', config)
if top1_acc > best_top1_acc:
best_top1_acc = top1_acc
message = "The best top1 acc {:.5f}, in epoch: {:d}".format(
best_top1_acc, epoch_id)
logger.info("{:s}".format(logger.coloring(message, "RED")))
if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(train_prog, model_path, "best_model")
# 3. save the persistable model
if epoch_id % config.save_interval == 0:
model_path = os.path.join(config.model_save_dir,
config.ARCHITECTURE["name"])
save_model(train_prog, model_path, epoch_id)
if __name__ == '__main__':
paddle.enable_static()
args = parse_args()
main(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册