未验证 提交 c82274bb 编写于 作者: K Kaipeng Deng 提交者: GitHub

Eval in train (#2121)

* eval in train
上级 52438b30
...@@ -79,7 +79,8 @@ class LogPrinter(Callback): ...@@ -79,7 +79,8 @@ class LogPrinter(Callback):
def on_step_end(self, status): def on_step_end(self, status):
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0: if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
if self.model.mode == 'train': mode = status['mode']
if mode == 'train':
epoch_id = status['epoch_id'] epoch_id = status['epoch_id']
step_id = status['step_id'] step_id = status['step_id']
steps_per_epoch = status['steps_per_epoch'] steps_per_epoch = status['steps_per_epoch']
...@@ -88,8 +89,8 @@ class LogPrinter(Callback): ...@@ -88,8 +89,8 @@ class LogPrinter(Callback):
data_time = status['data_time'] data_time = status['data_time']
epoches = self.model.cfg.epoch epoches = self.model.cfg.epoch
batch_size = self.model.cfg['{}Reader'.format( batch_size = self.model.cfg['{}Reader'.format(mode.capitalize(
self.model.mode.capitalize())]['batch_size'] ))]['batch_size']
logs = training_staus.log() logs = training_staus.log()
space_fmt = ':' + str(len(str(steps_per_epoch))) + 'd' space_fmt = ':' + str(len(str(steps_per_epoch))) + 'd'
...@@ -119,14 +120,15 @@ class LogPrinter(Callback): ...@@ -119,14 +120,15 @@ class LogPrinter(Callback):
dtime=str(data_time), dtime=str(data_time),
ips=ips) ips=ips)
logger.info(fmt) logger.info(fmt)
if self.model.mode == 'eval': if mode == 'eval':
step_id = status['step_id'] step_id = status['step_id']
if step_id % 100 == 0: if step_id % 100 == 0:
logger.info("Eval iter: {}".format(step_id)) logger.info("Eval iter: {}".format(step_id))
def on_epoch_end(self, status): def on_epoch_end(self, status):
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0: if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
if self.model.mode == 'eval': mode = status['mode']
if mode == 'eval':
sample_num = status['sample_num'] sample_num = status['sample_num']
cost_time = status['cost_time'] cost_time = status['cost_time']
logger.info('Total sample number: {}, averge FPS: {}'.format( logger.info('Total sample number: {}, averge FPS: {}'.format(
...@@ -147,8 +149,11 @@ class Checkpointer(Callback): ...@@ -147,8 +149,11 @@ class Checkpointer(Callback):
self.ema.update(self.model.model) self.ema.update(self.model.model)
def on_epoch_end(self, status): def on_epoch_end(self, status):
assert self.model.mode == 'train', \ # Checkpointer only performed during training
"Checkpointer can only be set during training" mode = status['mode']
if mode != 'train':
return
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0: if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0:
epoch_id = status['epoch_id'] epoch_id = status['epoch_id']
end_epoch = self.model.cfg.epoch end_epoch = self.model.cfg.epoch
......
...@@ -35,8 +35,7 @@ def init_parallel_env(): ...@@ -35,8 +35,7 @@ def init_parallel_env():
random.seed(local_seed) random.seed(local_seed)
np.random.seed(local_seed) np.random.seed(local_seed)
if ParallelEnv().nranks > 1: paddle.distributed.init_parallel_env()
paddle.distributed.init_parallel_env()
def set_random_seed(seed): def set_random_seed(seed):
......
...@@ -60,15 +60,19 @@ class Trainer(object): ...@@ -60,15 +60,19 @@ class Trainer(object):
slim = create(cfg.slim) slim = create(cfg.slim)
slim(self.model) slim(self.model)
if ParallelEnv().nranks > 1:
self.model = paddle.DataParallel(self.model)
# build data loader # build data loader
self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())] self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]
# TestDataset build after user set images, skip loader creation here if self.mode == 'train':
if self.mode != 'test':
self.loader = create('{}Reader'.format(self.mode.capitalize()))( self.loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, cfg.worker_num) self.dataset, cfg.worker_num)
# EvalDataset build with BatchSampler to evaluate in single device
# TODO: multi-device evaluate
if self.mode == 'eval':
self._eval_batch_sampler = paddle.io.BatchSampler(
self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
self.loader = create('{}Reader'.format(self.mode.capitalize()))(
self.dataset, cfg.worker_num, self._eval_batch_sampler)
# TestDataset build after user set images, skip loader creation here
# build optimizer in train mode # build optimizer in train mode
if self.mode == 'train': if self.mode == 'train':
...@@ -77,6 +81,9 @@ class Trainer(object): ...@@ -77,6 +81,9 @@ class Trainer(object):
self.optimizer = create('OptimizerBuilder')(self.lr, self.optimizer = create('OptimizerBuilder')(self.lr,
self.model.parameters()) self.model.parameters())
self._nranks = ParallelEnv().nranks
self._local_rank = ParallelEnv().local_rank
self.status = {} self.status = {}
self.start_epoch = 0 self.start_epoch = 0
...@@ -103,21 +110,18 @@ class Trainer(object): ...@@ -103,21 +110,18 @@ class Trainer(object):
self._compose_callback = None self._compose_callback = None
def _init_metrics(self): def _init_metrics(self):
if self.mode == 'eval': if self.cfg.metric == 'COCO':
if self.cfg.metric == 'COCO': self._metrics = [COCOMetric(anno_file=self.dataset.get_anno())]
self._metrics = [COCOMetric(anno_file=self.dataset.get_anno())] elif self.cfg.metric == 'VOC':
elif self.cfg.metric == 'VOC': self._metrics = [
self._metrics = [ VOCMetric(
VOCMetric( anno_file=self.dataset.get_anno(),
anno_file=self.dataset.get_anno(), class_num=self.cfg.num_classes,
class_num=self.cfg.num_classes, map_type=self.cfg.map_type)
map_type=self.cfg.map_type) ]
]
else:
logger.warn("Metric not support for metric type {}".format(
self.cfg.metric))
self._metrics = []
else: else:
logger.warn("Metric not support for metric type {}".format(
self.cfg.metric))
self._metrics = [] self._metrics = []
def _reset_metrics(self): def _reset_metrics(self):
...@@ -154,14 +158,16 @@ class Trainer(object): ...@@ -154,14 +158,16 @@ class Trainer(object):
weight_type, weights)) weight_type, weights))
self._weights_loaded = True self._weights_loaded = True
def train(self): def train(self, validate=False):
assert self.mode == 'train', "Model not in 'train' mode" assert self.mode == 'train', "Model not in 'train' mode"
self.model.train()
# if no given weights loaded, load backbone pretrain weights as default # if no given weights loaded, load backbone pretrain weights as default
if not self._weights_loaded: if not self._weights_loaded:
self.load_weights(self.cfg.pretrain_weights) self.load_weights(self.cfg.pretrain_weights)
if self._nranks > 1:
model = paddle.DataParallel(self.model)
self.status.update({ self.status.update({
'epoch_id': self.start_epoch, 'epoch_id': self.start_epoch,
'step_id': 0, 'step_id': 0,
...@@ -175,9 +181,11 @@ class Trainer(object): ...@@ -175,9 +181,11 @@ class Trainer(object):
self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter) self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)
for epoch_id in range(self.start_epoch, self.cfg.epoch): for epoch_id in range(self.start_epoch, self.cfg.epoch):
self.status['mode'] = 'train'
self.status['epoch_id'] = epoch_id self.status['epoch_id'] = epoch_id
self._compose_callback.on_epoch_begin(self.status) self._compose_callback.on_epoch_begin(self.status)
self.loader.dataset.set_epoch(epoch_id) self.loader.dataset.set_epoch(epoch_id)
model.train()
iter_tic = time.time() iter_tic = time.time()
for step_id, data in enumerate(self.loader): for step_id, data in enumerate(self.loader):
self.status['data_time'].update(time.time() - iter_tic) self.status['data_time'].update(time.time() - iter_tic)
...@@ -185,7 +193,7 @@ class Trainer(object): ...@@ -185,7 +193,7 @@ class Trainer(object):
self._compose_callback.on_step_begin(self.status) self._compose_callback.on_step_begin(self.status)
# model forward # model forward
outputs = self.model(data) outputs = model(data)
loss = outputs['loss'] loss = outputs['loss']
# model backward # model backward
...@@ -196,23 +204,42 @@ class Trainer(object): ...@@ -196,23 +204,42 @@ class Trainer(object):
self.optimizer.clear_grad() self.optimizer.clear_grad()
self.status['learning_rate'] = curr_lr self.status['learning_rate'] = curr_lr
if ParallelEnv().nranks < 2 or ParallelEnv().local_rank == 0: if self._nranks < 2 or self._local_rank == 0:
self.status['training_staus'].update(outputs) self.status['training_staus'].update(outputs)
self.status['batch_time'].update(time.time() - iter_tic) self.status['batch_time'].update(time.time() - iter_tic)
self._compose_callback.on_step_end(self.status) self._compose_callback.on_step_end(self.status)
iter_tic = time.time() iter_tic = time.time()
self._compose_callback.on_epoch_end(self.status) self._compose_callback.on_epoch_end(self.status)
def evaluate(self): if validate and (self._nranks < 2 or self._local_rank == 0) \
and (epoch_id % self.cfg.snapshot_epoch == 0 \
or epoch_id == self.end_epoch - 1):
if not hasattr(self, '_eval_loader'):
# build evaluation dataset and loader
self._eval_dataset = self.cfg.EvalDataset
self._eval_batch_sampler = \
paddle.io.BatchSampler(
self._eval_dataset,
batch_size=self.cfg.EvalReader['batch_size'])
self._eval_loader = create('EvalReader')(
self._eval_dataset,
self.cfg.worker_num,
batch_sampler=self._eval_batch_sampler)
with paddle.no_grad():
self._eval_with_loader(self._eval_loader)
def _eval_with_loader(self, loader):
sample_num = 0 sample_num = 0
tic = time.time() tic = time.time()
self._compose_callback.on_epoch_begin(self.status) self._compose_callback.on_epoch_begin(self.status)
for step_id, data in enumerate(self.loader): self.status['mode'] = 'eval'
self.model.eval()
for step_id, data in enumerate(loader):
self.status['step_id'] = step_id self.status['step_id'] = step_id
self._compose_callback.on_step_begin(self.status) self._compose_callback.on_step_begin(self.status)
# forward # forward
self.model.eval()
outs = self.model(data) outs = self.model(data)
# update metrics # update metrics
...@@ -233,6 +260,9 @@ class Trainer(object): ...@@ -233,6 +260,9 @@ class Trainer(object):
# reset metric states for metric may performed multiple times # reset metric states for metric may performed multiple times
self._reset_metrics() self._reset_metrics()
def evaluate(self):
self._eval_with_loader(self.loader)
def predict(self, images, draw_threshold=0.5, output_dir='output'): def predict(self, images, draw_threshold=0.5, output_dir='output'):
self.dataset.set_images(images) self.dataset.set_images(images)
loader = create('TestReader')(self.dataset, 0) loader = create('TestReader')(self.dataset, 0)
...@@ -242,11 +272,12 @@ class Trainer(object): ...@@ -242,11 +272,12 @@ class Trainer(object):
anno_file = self.dataset.get_anno() anno_file = self.dataset.get_anno()
clsid2catid, catid2name = get_categories(self.cfg.metric, anno_file) clsid2catid, catid2name = get_categories(self.cfg.metric, anno_file)
# Run Infer # Run Infer
self.status['mode'] = 'test'
self.model.eval()
for step_id, data in enumerate(loader): for step_id, data in enumerate(loader):
self.status['step_id'] = step_id self.status['step_id'] = step_id
# forward # forward
self.model.eval()
outs = self.model(data) outs = self.model(data)
for key in ['im_shape', 'scale_factor', 'im_id']: for key in ['im_shape', 'scale_factor', 'im_id']:
outs[key] = data[key] outs[key] = data[key]
...@@ -301,6 +332,8 @@ class Trainer(object): ...@@ -301,6 +332,8 @@ class Trainer(object):
if image_shape is None: if image_shape is None:
image_shape = [3, None, None] image_shape = [3, None, None]
self.model.eval()
# Save infer cfg # Save infer cfg
_dump_infer_config(self.cfg, _dump_infer_config(self.cfg,
os.path.join(save_dir, 'infer_cfg.yml'), image_shape, os.path.join(save_dir, 'infer_cfg.yml'), image_shape,
......
...@@ -51,6 +51,7 @@ packages = [ ...@@ -51,6 +51,7 @@ packages = [
'ppdet.core', 'ppdet.core',
'ppdet.data', 'ppdet.data',
'ppdet.engine', 'ppdet.engine',
'ppdet.metrics',
'ppdet.modeling', 'ppdet.modeling',
'ppdet.model_zoo', 'ppdet.model_zoo',
'ppdet.py_op', 'ppdet.py_op',
......
...@@ -32,7 +32,7 @@ from paddle.distributed import ParallelEnv ...@@ -32,7 +32,7 @@ from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config from ppdet.core.workspace import load_config, merge_config
from ppdet.utils.check import check_gpu, check_version, check_config from ppdet.utils.check import check_gpu, check_version, check_config
from ppdet.utils.cli import ArgsParser from ppdet.utils.cli import ArgsParser
from ppdet.engine import Trainer from ppdet.engine import Trainer, init_parallel_env
from ppdet.utils.logger import setup_logger from ppdet.utils.logger import setup_logger
logger = setup_logger('eval') logger = setup_logger('eval')
...@@ -60,6 +60,9 @@ def parse_args(): ...@@ -60,6 +60,9 @@ def parse_args():
def run(FLAGS, cfg): def run(FLAGS, cfg):
# init parallel environment if nranks > 1
init_parallel_env()
# build trainer # build trainer
trainer = Trainer(cfg, mode='eval') trainer = Trainer(cfg, mode='eval')
......
...@@ -84,7 +84,7 @@ def run(FLAGS, cfg): ...@@ -84,7 +84,7 @@ def run(FLAGS, cfg):
trainer.load_weights(cfg.pretrain_weights, FLAGS.weight_type) trainer.load_weights(cfg.pretrain_weights, FLAGS.weight_type)
# training # training
trainer.train() trainer.train(FLAGS.eval)
def main(): def main():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册