提交 32dc1c1c 编写于 作者: littletomatodonkey's avatar littletomatodonkey

improve dygraph model

上级 26289ce0
......@@ -31,12 +31,12 @@ def check_version():
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 2.0.0 or higher is required, " \
err = "PaddlePaddle version 1.8.0 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('2.0.0')
fluid.require_version('1.8.0')
except Exception:
logger.error(err)
sys.exit(1)
......
......@@ -64,14 +64,18 @@ def print_dict(d, delimiter=0):
placeholder = "-" * 60
for k, v in sorted(d.items()):
if isinstance(v, dict):
logger.info("{}{} : ".format(delimiter * " ", logger.coloring(k, "HEADER")))
logger.info("{}{} : ".format(delimiter * " ",
logger.coloring(k, "HEADER")))
print_dict(v, delimiter + 4)
elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
logger.info("{}{} : ".format(delimiter * " ", logger.coloring(str(k),"HEADER")))
logger.info("{}{} : ".format(delimiter * " ",
logger.coloring(str(k), "HEADER")))
for value in v:
print_dict(value, delimiter + 4)
else:
logger.info("{}{} : {}".format(delimiter * " ", logger.coloring(k,"HEADER"), logger.coloring(v,"OKGREEN")))
logger.info("{}{} : {}".format(delimiter * " ",
logger.coloring(k, "HEADER"),
logger.coloring(v, "OKGREEN")))
if k.isupper():
logger.info(placeholder)
......@@ -138,7 +142,9 @@ def override(dl, ks, v):
override(dl[k], ks[1:], v)
else:
if len(ks) == 1:
assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
# assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
if not ks[0] in dl:
logger.warning('A new filed ({}) detected!'.format(ks[0], dl))
dl[ks[0]] = str2num(v)
else:
override(dl[ks[0]], ks[1:], v)
......
......@@ -35,8 +35,6 @@ from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.incubate.fleet.collective import fleet
from paddle.fluid.incubate.fleet.collective import DistributedStrategy
def create_dataloader():
......@@ -243,43 +241,6 @@ def create_optimizer(config, parameter_list=None):
return opt(lr, parameter_list)
def dist_optimizer(config, optimizer):
"""
Create a distributed optimizer based on a normal optimizer
Args:
config(dict):
optimizer(): a normal optimizer
Returns:
optimizer: a distributed optimizer
"""
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_threads = 3
exec_strategy.num_iteration_per_drop_scope = 10
dist_strategy = DistributedStrategy()
dist_strategy.nccl_comm_num = 1
dist_strategy.fuse_all_reduce_ops = True
dist_strategy.exec_strategy = exec_strategy
optimizer = fleet.distributed_optimizer(optimizer, strategy=dist_strategy)
return optimizer
def mixed_precision_optimizer(config, optimizer):
use_fp16 = config.get('use_fp16', False)
amp_scale_loss = config.get('amp_scale_loss', 1.0)
use_dynamic_loss_scaling = config.get('use_dynamic_loss_scaling', False)
if use_fp16:
optimizer = fluid.contrib.mixed_precision.decorate(
optimizer,
init_loss_scaling=amp_scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
return optimizer
def create_feeds(batch, use_mix):
image = batch[0]
if use_mix:
......@@ -307,26 +268,22 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
Returns:
"""
print_interval = config.get("print_interval", 10)
use_mix = config.get("use_mix", False) and mode == "train"
if use_mix:
metric_list = OrderedDict([
metric_list = [
("loss", AverageMeter('loss', '7.4f')),
("lr", AverageMeter(
'lr', 'f', need_avg=False)),
("batch_time", AverageMeter('elapse', '.3f')),
('reader_time', AverageMeter('reader', '.3f')),
])
else:
]
if not use_mix:
topk_name = 'top{}'.format(config.topk)
metric_list = OrderedDict([
("loss", AverageMeter('loss', '7.4f')),
("top1", AverageMeter('top1', '.4f')),
(topk_name, AverageMeter(topk_name, '.4f')),
("lr", AverageMeter(
'lr', 'f', need_avg=False)),
("batch_time", AverageMeter('elapse', '.3f')),
('reader_time', AverageMeter('reader', '.3f')),
])
metric_list.insert(1, (topk_name, AverageMeter(topk_name, '.4f')))
metric_list.insert(1, ("top1", AverageMeter("top1", '.4f')))
metric_list = OrderedDict(metric_list)
tic = time.time()
for idx, batch in enumerate(dataloader()):
......@@ -354,12 +311,14 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
tic = time.time()
fetchs_str = ' '.join([str(m.value) for m in metric_list.values()])
if idx % print_interval == 0:
if mode == 'eval':
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx,
fetchs_str))
else:
epoch_str = "epoch:{:<3d}".format(epoch)
step_str = "{:s} step:{:<4d}".format(mode, idx)
logger.info("{:s} {:s} {:s}s".format(
logger.coloring(epoch_str, "HEADER")
if idx == 0 else epoch_str,
......
......@@ -5,4 +5,5 @@ export PYTHONPATH=$PWD:$PYTHONPATH
python -m paddle.distributed.launch \
--selected_gpus="0,1,2,3" \
tools/train.py \
-c ./configs/ResNet/ResNet50.yaml
-c ./configs/ResNet/ResNet50_vd.yaml \
-o print_interval=10
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册