Unverified commit c82274bb, authored by Kaipeng Deng, committed by GitHub

Eval in train (#2121)

* eval in train
Parent 52438b30
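This change adds evaluation during training: `Trainer.train()` gains a `validate` switch, and `tools/train.py` forwards it as `trainer.train(FLAGS.eval)`, so validation runs on rank 0 every `snapshot_epoch` epochs and after the last epoch. A minimal usage sketch against the new API (the config path is a placeholder, and the command-line flag name behind `FLAGS.eval` is an assumption):

```python
# Sketch of eval-in-train via the Python API added in this commit.
# The config path below is a placeholder, not part of the change.
from ppdet.core.workspace import load_config
from ppdet.engine import Trainer

cfg = load_config('configs/yolov3_darknet53_270e_coco.yml')  # any valid detection config
trainer = Trainer(cfg, mode='train')

# train() falls back to cfg.pretrain_weights if no weights were loaded beforehand;
# validate=True evaluates on rank 0 every cfg.snapshot_epoch epochs and after the
# final epoch, which is the path tools/train.py takes with trainer.train(FLAGS.eval).
trainer.train(validate=True)
```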
@@ -79,7 +79,8 @@ class LogPrinter(Callback):
def on_step_end(self, status):
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
if self.model.mode == 'train':
mode = status['mode']
if mode == 'train':
epoch_id = status['epoch_id']
step_id = status['step_id']
steps_per_epoch = status['steps_per_epoch']
@@ -88,8 +89,8 @@ class LogPrinter(Callback):
data_time = status['data_time']
epoches = self.model.cfg.epoch
batch_size = self.model.cfg['{}Reader'.format(
self.model.mode.capitalize())]['batch_size']
batch_size = self.model.cfg['{}Reader'.format(mode.capitalize(
))]['batch_size']
logs = training_staus.log()
space_fmt = ':' + str(len(str(steps_per_epoch))) + 'd'
@@ -119,14 +120,15 @@ class LogPrinter(Callback):
dtime=str(data_time),
ips=ips)
logger.info(fmt)
if self.model.mode == 'eval':
if mode == 'eval':
step_id = status['step_id']
if step_id % 100 == 0:
logger.info("Eval iter: {}".format(step_id))
def on_epoch_end(self, status):
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
if self.model.mode == 'eval':
mode = status['mode']
if mode == 'eval':
sample_num = status['sample_num']
cost_time = status['cost_time']
logger.info('Total sample number: {}, averge FPS: {}'.format(
@@ -147,8 +149,11 @@ class Checkpointer(Callback):
self.ema.update(self.model.model)
def on_epoch_end(self, status):
assert self.model.mode == 'train', \
"Checkpointer can only be set during training"
# Checkpointer only performed during training
mode = status['mode']
if mode != 'train':
return
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
epoch_id = status['epoch_id']
end_epoch = self.model.cfg.epoch
......
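The callback changes above all follow one pattern: because evaluation can now happen in the middle of a training run, callbacks stop reading `self.model.mode` and instead branch on `status['mode']`, which the trainer sets to `'train'`, `'eval'`, or `'test'` before each phase (see the trainer changes below). Below is a minimal sketch of a user callback written against that convention; the `ppdet.engine.callbacks` import path and the `Callback(model)` constructor are assumptions based on how `LogPrinter` and `Checkpointer` are used here:

```python
# Illustrative custom callback; not part of the commit.
# Assumes the Callback base class stores the trainer as self.model, as LogPrinter does.
from ppdet.engine.callbacks import Callback  # import path assumed


class PhaseCounter(Callback):
    """Counts steps separately for the train and eval phases of one run."""

    def __init__(self, model):
        super(PhaseCounter, self).__init__(model)
        self.steps = {'train': 0, 'eval': 0}

    def on_step_end(self, status):
        mode = status['mode']  # set by Trainer before each phase
        if mode in self.steps:
            self.steps[mode] += 1

    def on_epoch_end(self, status):
        # Report once per training epoch, not again after the eval pass.
        if status['mode'] == 'train':
            print('steps so far:', self.steps)
```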
@@ -35,7 +35,6 @@ def init_parallel_env():
random.seed(local_seed)
np.random.seed(local_seed)
if ParallelEnv().nranks > 1:
paddle.distributed.init_parallel_env()
......
@@ -60,15 +60,19 @@ class Trainer(object):
slim = create(cfg.slim)
slim(self.model)
if ParallelEnv().nranks > 1:
self.model = paddle.DataParallel(self.model)
# build data loader
self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]
# TestDataset build after user set images, skip loader creation here
if self.mode != 'test':
if self.mode == 'train':
self.loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, cfg.worker_num)
# EvalDataset build with BatchSampler to evaluate in single device
# TODO: multi-device evaluate
if self.mode == 'eval':
self._eval_batch_sampler = paddle.io.BatchSampler(
self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
self.loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, cfg.worker_num, self._eval_batch_sampler)
# TestDataset build after user set images, skip loader creation here
# build optimizer in train mode
if self.mode == 'train':
@@ -77,6 +81,9 @@ class Trainer(object):
self.optimizer = create('OptimizerBuilder')(self.lr,
self.model.parameters())
self._nranks = ParallelEnv().nranks
self._local_rank = ParallelEnv().local_rank
self.status = {}
self.start_epoch = 0
@@ -103,7 +110,6 @@ class Trainer(object):
self._compose_callback = None
def _init_metrics(self):
if self.mode == 'eval':
if self.cfg.metric == 'COCO':
self._metrics = [COCOMetric(anno_file=self.dataset.get_anno())]
elif self.cfg.metric == 'VOC':
@@ -117,8 +123,6 @@ class Trainer(object):
logger.warn("Metric not support for metric type {}".format(
self.cfg.metric))
self._metrics = []
else:
self._metrics = []
def _reset_metrics(self):
for metric in self._metrics:
@@ -154,14 +158,16 @@ class Trainer(object):
weight_type, weights))
self._weights_loaded = True
def train(self):
def train(self, validate=False):
assert self.mode == 'train', "Model not in 'train' mode"
self.model.train()
# if no given weights loaded, load backbone pretrain weights as default
if not self._weights_loaded:
self.load_weights(self.cfg.pretrain_weights)
if self._nranks > 1:
model = paddle.DataParallel(self.model)
self.status.update({
'epoch_id': self.start_epoch,
'step_id': 0,
@@ -175,9 +181,11 @@ class Trainer(object):
self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)
for epoch_id in range(self.start_epoch, self.cfg.epoch):
self.status['mode'] = 'train'
self.status['epoch_id'] = epoch_id
self._compose_callback.on_epoch_begin(self.status)
self.loader.dataset.set_epoch(epoch_id)
model.train()
iter_tic = time.time()
for step_id, data in enumerate(self.loader):
self.status['data_time'].update(time.time() - iter_tic)
@@ -185,7 +193,7 @@ class Trainer(object):
self._compose_callback.on_step_begin(self.status)
# model forward
outputs = self.model(data)
outputs = model(data)
loss = outputs['loss']
# model backward
@@ -196,23 +204,42 @@ class Trainer(object):
self.optimizer.clear_grad()
self.status['learning_rate'] = curr_lr
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
if self._nranks < 2 or self._local_rank == 0:
self.status['training_staus'].update(outputs)
self.status['batch_time'].update(time.time() - iter_tic)
self._compose_callback.on_step_end(self.status)
iter_tic = time.time()
self._compose_callback.on_epoch_end(self.status)
def evaluate(self):
if validate and (self._nranks < 2 or self._local_rank == 0) \
and (epoch_id % self.cfg.snapshot_epoch == 0 \
or epoch_id == self.end_epoch - 1):
if not hasattr(self, '_eval_loader'):
# build evaluation dataset and loader
self._eval_dataset = self.cfg.EvalDataset
self._eval_batch_sampler = \
paddle.io.BatchSampler(
self._eval_dataset,
batch_size=self.cfg.EvalReader['batch_size'])
self._eval_loader = create('EvalReader')(
self._eval_dataset,
self.cfg.worker_num,
batch_sampler=self._eval_batch_sampler)
with paddle.no_grad():
self._eval_with_loader(self._eval_loader)
def _eval_with_loader(self, loader):
sample_num = 0
tic = time.time()
self._compose_callback.on_epoch_begin(self.status)
for step_id, data in enumerate(self.loader):
self.status['mode'] = 'eval'
self.model.eval()
for step_id, data in enumerate(loader):
self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status)
# forward
self.model.eval()
outs = self.model(data)
# update metrics
@@ -233,6 +260,9 @@ class Trainer(object):
# reset metric states for metric may performed multiple times
self._reset_metrics()
def evaluate(self):
self._eval_with_loader(self.loader)
def predict(self, images, draw_threshold=0.5, output_dir='output'):
self.dataset.set_images(images)
loader = create('TestReader')(self.dataset, 0)
@@ -243,10 +273,11 @@ class Trainer(object):
clsid2catid, catid2name = get_categories(self.cfg.metric, anno_file)
# Run Infer
self.status['mode'] = 'test'
self.model.eval()
for step_id, data in enumerate(loader):
self.status['step_id'] = step_id
# forward
self.model.eval()
outs = self.model(data)
for key in ['im_shape', 'scale_factor', 'im_id']:
outs[key] = data[key]
@@ -301,6 +332,8 @@ class Trainer(object):
if image_shape is None:
image_shape = [3, None, None]
self.model.eval()
# Save infer cfg
_dump_infer_config(self.cfg,
os.path.join(save_dir, 'infer_cfg.yml'), image_shape,
......
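With `evaluate()` now a thin wrapper around `_eval_with_loader()`, standalone evaluation (driven by `tools/eval.py` below) and the new in-training validation share the same loop and metrics. A sketch of the standalone path, mirroring the `run()` function shown further down; the config path and the `cfg.weights` key used for trained weights are placeholders/assumptions:

```python
# Standalone evaluation sketch mirroring tools/eval.py; paths and weight key are illustrative.
from ppdet.core.workspace import load_config
from ppdet.engine import Trainer, init_parallel_env

cfg = load_config('configs/yolov3_darknet53_270e_coco.yml')  # placeholder config
init_parallel_env()  # seeds RNGs; initializes distributed state only when nranks > 1

trainer = Trainer(cfg, mode='eval')   # builds EvalDataset with a single-device BatchSampler
trainer.load_weights(cfg.weights)     # assumed config key for the trained weights
trainer.evaluate()                    # delegates to _eval_with_loader(self.loader)
```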
@@ -51,6 +51,7 @@ packages = [
'ppdet.core',
'ppdet.data',
'ppdet.engine',
'ppdet.metrics',
'ppdet.modeling',
'ppdet.model_zoo',
'ppdet.py_op',
......
@@ -32,7 +32,7 @@ from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config
from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser
from ppdet.engine import Trainer
from ppdet.engine import Trainer, init_parallel_env
from ppdet.utils.logger import setup_logger
logger = setup_logger('eval')
@@ -60,6 +60,9 @@ def parse_args():
def run(FLAGS, cfg):
# init parallel environment if nranks > 1
init_parallel_env()
# build trainer
trainer = Trainer(cfg, mode='eval')
......
@@ -84,7 +84,7 @@ def run(FLAGS, cfg):
trainer.load_weights(cfg.pretrain_weights, FLAGS.weight_type)
# training
trainer.train()
trainer.train(FLAGS.eval)
def main():
......