# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:58
# @Author : zhoujun
import time
import paddle
from tqdm import tqdm
from base import BaseTrainer
from utils import runningScore, cal_text_score, Polynomial, profiler


class Trainer(BaseTrainer):
    """Training driver for a shrink-map text-detection model (DBNet-style).

    Extends ``BaseTrainer`` with a Paddle training loop that supports
    automatic mixed precision (AMP), per-iteration profiling, periodic
    logging/VisualDL scalars, and best-model selection by hmean (F1) on the
    validation set or, when no validation loader/metric is configured, by
    lowest training loss.
    """

    def __init__(self,
                 config,
                 model,
                 criterion,
                 train_loader,
                 validate_loader,
                 metric_cls,
                 post_process=None,
                 profiler_options=None):
        """Store trainer options on top of BaseTrainer's setup.

        :param config: full experiment config dict; ``config['trainer']['enable_eval']``
            (default True) toggles validation at epoch end.
        :param model: the detection network to train.
        :param criterion: loss callable taking ``(preds, batch)``.
        :param train_loader: training data loader.
        :param validate_loader: validation data loader (may be None).
        :param metric_cls: metric object providing ``validate_measure`` /
            ``gather_measure`` / ``is_output_polygon``.
        :param post_process: callable turning raw predictions into boxes/scores.
        :param profiler_options: options string forwarded to
            ``profiler.add_profiler_step`` each iteration.
        """
        super(Trainer, self).__init__(config, model, criterion, train_loader,
                                      validate_loader, metric_cls, post_process)
        self.profiler_options = profiler_options
        self.enable_eval = config['trainer'].get('enable_eval', True)

    def _train_epoch(self, epoch):
        """Run one training epoch and return summary stats.

        :param epoch: 1-based (presumably) epoch index, used only for logging
            and the returned dict.
        :return: dict with ``train_loss`` (mean loss over the epoch), ``lr``
            (last learning rate seen), ``time`` (epoch wall time, seconds)
            and ``epoch``.
        """
        self.model.train()
        total_samples = 0
        train_reader_cost = 0.0
        train_batch_cost = 0.0
        reader_start = time.time()
        epoch_start = time.time()
        train_loss = 0.
        # Binary (text / non-text) running confusion matrix for acc/IoU.
        running_metric_text = runningScore(2)
        for i, batch in enumerate(self.train_loader):
            profiler.add_profiler_step(self.profiler_options)
            # Guard against loaders that yield more batches than the
            # configured epoch length.
            if i >= self.train_loader_len:
                break
            self.global_step += 1
            lr = self.optimizer.get_lr()
            cur_batch_size = batch['img'].shape[0]
            # Time spent waiting on the data loader since the last batch.
            train_reader_cost += time.time() - reader_start
            if self.amp:
                # AMP path: forward under auto_cast, then scale the loss
                # before backward. NOTE(review): scaler.minimize() performs
                # the optimizer update here, so optimizer.step() is only
                # called in the non-AMP branch below.
                with paddle.amp.auto_cast(
                        enable='gpu' in paddle.device.get_device(),
                        custom_white_list=self.amp.get('custom_white_list', []),
                        custom_black_list=self.amp.get('custom_black_list', []),
                        level=self.amp.get('level', 'O2')):
                    preds = self.model(batch['img'])
                # Loss is computed in fp32 regardless of the autocast dtype.
                loss_dict = self.criterion(preds.astype(paddle.float32), batch)
                scaled_loss = self.amp['scaler'].scale(loss_dict['loss'])
                scaled_loss.backward()
                self.amp['scaler'].minimize(self.optimizer, scaled_loss)
            else:
                preds = self.model(batch['img'])
                loss_dict = self.criterion(preds, batch)
                # backward
                loss_dict['loss'].backward()
                self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.clear_grad()
            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            total_samples += cur_batch_size
            # acc / iou on the shrink map: channel 0 of preds is (presumably)
            # the shrink probability map — TODO confirm against the model head.
            score_shrink_map = cal_text_score(
                preds[:, 0, :, :],
                batch['shrink_map'],
                batch['shrink_mask'],
                running_metric_text,
                thred=self.config['post_processing']['args']['thresh'])
            # Record loss and acc in the log. Note this also converts every
            # tensor in loss_dict to a Python float via .item().
            loss_str = 'loss: {:.4f}, '.format(loss_dict['loss'].item())
            for idx, (key, value) in enumerate(loss_dict.items()):
                loss_dict[key] = value.item()
                if key == 'loss':
                    continue
                loss_str += '{}: {:.4f}'.format(key, loss_dict[key])
                if idx < len(loss_dict) - 1:
                    loss_str += ', '
            train_loss += loss_dict['loss']
            acc = score_shrink_map['Mean Acc']
            iou_shrink_map = score_shrink_map['Mean IoU']
            if self.global_step % self.log_iter == 0:
                self.logger_info(
                    '[{}/{}], [{}/{}], global_step: {}, ips: {:.1f} samples/sec, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, acc: {:.4f}, iou_shrink_map: {:.4f}, {}lr:{:.6}, time:{:.2f}'.
                    format(epoch, self.epochs, i + 1, self.train_loader_len,
                           self.global_step, total_samples / train_batch_cost,
                           train_reader_cost / self.log_iter,
                           train_batch_cost / self.log_iter,
                           total_samples / self.log_iter, acc, iou_shrink_map,
                           loss_str, lr, train_batch_cost))
                # Reset the windowed counters so the next log line reports
                # averages over the last log_iter iterations only.
                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0
            if self.visualdl_enable and paddle.distributed.get_rank() == 0:
                # write tensorboard (VisualDL) scalars from rank 0 only
                for key, value in loss_dict.items():
                    self.writer.add_scalar('TRAIN/LOSS/{}'.format(key), value,
                                           self.global_step)
                self.writer.add_scalar('TRAIN/ACC_IOU/acc', acc,
                                       self.global_step)
                self.writer.add_scalar('TRAIN/ACC_IOU/iou_shrink_map',
                                       iou_shrink_map, self.global_step)
                self.writer.add_scalar('TRAIN/lr', lr, self.global_step)
            reader_start = time.time()
        return {
            'train_loss': train_loss / self.train_loader_len,
            'lr': lr,
            'time': time.time() - epoch_start,
            'epoch': epoch
        }

    def _eval(self, epoch):
        """Evaluate the model on the validation loader.

        :param epoch: epoch index (accepted but unused in the body).
        :return: tuple ``(recall, precision, fmeasure)`` averages as gathered
            by ``self.metric_cls.gather_measure``.
        """
        self.model.eval()
        raw_metrics = []
        total_frame = 0.0
        total_time = 0.0
        for i, batch in tqdm(
                enumerate(self.validate_loader),
                total=len(self.validate_loader),
                desc='test model'):
            with paddle.no_grad():
                start = time.time()
                if self.amp:
                    with paddle.amp.auto_cast(
                            enable='gpu' in paddle.device.get_device(),
                            custom_white_list=self.amp.get(
                                'custom_white_list', []),
                            custom_black_list=self.amp.get(
                                'custom_black_list', []),
                            level=self.amp.get('level', 'O2')):
                        preds = self.model(batch['img'])
                    # Cast back to fp32 so post-processing sees full precision.
                    preds = preds.astype(paddle.float32)
                else:
                    preds = self.model(batch['img'])
                boxes, scores = self.post_process(
                    batch,
                    preds,
                    is_output_polygon=self.metric_cls.is_output_polygon)
                total_frame += batch['img'].shape[0]
                # Timing covers forward + post-processing for the FPS figure.
                total_time += time.time() - start
                raw_metric = self.metric_cls.validate_measure(batch,
                                                              (boxes, scores))
                raw_metrics.append(raw_metric)
        metrics = self.metric_cls.gather_measure(raw_metrics)
        self.logger_info('FPS:{}'.format(total_frame / total_time))
        return metrics['recall'].avg, metrics['precision'].avg, metrics[
            'fmeasure'].avg

    def _on_epoch_finish(self):
        """Epoch-end hook: log results, checkpoint, and track the best model.

        Rank 0 only: saves ``model_latest.pth``, optionally runs evaluation
        (when a validate loader, metric and ``enable_eval`` are all present)
        and copies the checkpoint to ``model_best.pth`` when the tracked
        metric improves (hmean when evaluating, otherwise training loss).
        Reads ``self.epoch_result`` as produced by ``_train_epoch``.
        """
        self.logger_info('[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}'.
                         format(self.epoch_result['epoch'], self.epochs, self.
                                epoch_result['train_loss'], self.epoch_result[
                                    'time'], self.epoch_result['lr']))
        net_save_path = '{}/model_latest.pth'.format(self.checkpoint_dir)
        net_save_path_best = '{}/model_best.pth'.format(self.checkpoint_dir)
        if paddle.distributed.get_rank() == 0:
            self._save_checkpoint(self.epoch_result['epoch'], net_save_path)
            save_best = False
            if self.validate_loader is not None and self.metric_cls is not None and self.enable_eval:
                # Use f1 (hmean) as the best-model selection metric.
                recall, precision, hmean = self._eval(self.epoch_result[
                    'epoch'])
                if self.visualdl_enable:
                    self.writer.add_scalar('EVAL/recall', recall,
                                           self.global_step)
                    self.writer.add_scalar('EVAL/precision', precision,
                                           self.global_step)
                    self.writer.add_scalar('EVAL/hmean', hmean,
                                           self.global_step)
                self.logger_info(
                    'test: recall: {:.6f}, precision: {:.6f}, hmean: {:.6f}'.
                    format(recall, precision, hmean))
                # '>=' means ties refresh the stored best as well.
                if hmean >= self.metrics['hmean']:
                    save_best = True
                    self.metrics['train_loss'] = self.epoch_result['train_loss']
                    self.metrics['hmean'] = hmean
                    self.metrics['precision'] = precision
                    self.metrics['recall'] = recall
                    self.metrics['best_model_epoch'] = self.epoch_result[
                        'epoch']
            else:
                # No evaluation available: fall back to lowest training loss.
                if self.epoch_result['train_loss'] <= self.metrics[
                        'train_loss']:
                    save_best = True
                    self.metrics['train_loss'] = self.epoch_result['train_loss']
                    self.metrics['best_model_epoch'] = self.epoch_result[
                        'epoch']
            best_str = 'current best, '
            for k, v in self.metrics.items():
                best_str += '{}: {:.6f}, '.format(k, v)
            self.logger_info(best_str)
            if save_best:
                import shutil
                shutil.copy(net_save_path, net_save_path_best)
                self.logger_info("Saving current best: {}".format(
                    net_save_path_best))
            else:
                self.logger_info("Saving checkpoint: {}".format(net_save_path))

    def _on_train_finish(self):
        """Training-end hook: log the final best metrics (if eval was on)."""
        if self.enable_eval:
            for k, v in self.metrics.items():
                self.logger_info('{}:{}'.format(k, v))
        self.logger_info('finish train')

    def _initialize_scheduler(self):
        """Build ``self.lr_scheduler`` from config.

        The custom ``Polynomial`` schedule needs the epoch count and
        steps-per-epoch injected into its args; any other scheduler type is
        resolved generically from ``paddle.optimizer.lr`` via
        ``self._initialize``.
        """
        if self.config['lr_scheduler']['type'] == 'Polynomial':
            self.config['lr_scheduler']['args']['epochs'] = self.config[
                'trainer']['epochs']
            self.config['lr_scheduler']['args']['step_each_epoch'] = len(
                self.train_loader)
            # Polynomial(...) is a factory: calling the instance returns the
            # actual Paddle LRScheduler object.
            self.lr_scheduler = Polynomial(
                **self.config['lr_scheduler']['args'])()
        else:
            self.lr_scheduler = self._initialize('lr_scheduler',
                                                 paddle.optimizer.lr)