diff --git a/configs/det/det_mv3_db_amp.yml b/configs/det/det_mv3_db_amp.yml new file mode 100644 index 0000000000000000000000000000000000000000..772342a2d34dfb2ed2975b72970f811c9300c473 --- /dev/null +++ b/configs/det/det_mv3_db_amp.yml @@ -0,0 +1,135 @@ +Global: + use_gpu: true + epoch_num: 1200 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/db_mv3/ + save_epoch_step: 1200 + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] + cal_metric_during_train: False + pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./output/det_db/predicts_db.txt + +AMP: + scale_loss: 1024.0 + use_dynamic_loss_scaling: True + +Architecture: + model_type: det + algorithm: DB + Transform: + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: large + Neck: + name: DBFPN + out_channels: 256 + Head: + name: DBHead + k: 50 + +Loss: + name: DBLoss + balance_loss: true + main_loss_type: DiceLoss + alpha: 5 + beta: 10 + ohem_ratio: 3 + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + learning_rate: 0.001 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: DBPostProcess + thresh: 0.3 + box_thresh: 0.6 + max_candidates: 1000 + unclip_ratio: 1.5 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [1.0] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - IaaAugment: + augmenter_args: + - { 'type': Fliplr, 'args': { 'p': 0.5 } } + - { 'type': Affine, 'args': { 'rotate': [-10, 10] } } + - { 'type': Resize, 'args': { 'size': [0.5, 3] } } + - EastRandomCropData: + size: [640, 640] + max_tries: 50 + keep_ratio: true + - MakeBorderMap: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + - MakeShrinkMap: + shrink_ratio: 0.4 + min_text_size: 8 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list + loader: + shuffle: True + drop_last: False + batch_size_per_card: 16 + num_workers: 8 + use_shared_memory: False + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: + image_shape: [736, 1280] + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 8 + use_shared_memory: False diff --git a/tools/program.py b/tools/program.py index 798e6dff297ad1149942488cca1d5540f1924867..5963016b66d4d254e180b0bc9ad098e49518969f 100755 --- a/tools/program.py +++ b/tools/program.py @@ -226,14 +226,29 @@ def train(config, images = batch[0] if use_srn: model_average = True - if model_type == 'table' or extra_input: - preds = model(images, data=batch[1:]) + + # use amp + if scaler: + with paddle.amp.auto_cast(): + if model_type == 'table' or extra_input: + preds = model(images, data=batch[1:]) + else: + preds = model(images) else: - preds = model(images) + if model_type == 'table' or extra_input: + preds = model(images, data=batch[1:]) + else: + preds = model(images) loss = loss_class(preds, batch) avg_loss = loss['loss'] - avg_loss.backward() - optimizer.step() + + if scaler: + scaled_avg_loss = scaler.scale(avg_loss) + scaled_avg_loss.backward() + scaler.minimize(optimizer, scaled_avg_loss) + else: + avg_loss.backward() + optimizer.step() optimizer.clear_grad() train_batch_cost += time.time() - batch_start diff --git a/tools/train.py b/tools/train.py index 05d295aa99718c25b94a123c23d08c2904fe8c6a..b34ac9790e4ff776a79e5e9d556d2dd0e020911d 100755 --- a/tools/train.py +++ b/tools/train.py @@ -102,6 +102,23 @@ def main(config, device, logger, vdl_writer): if valid_dataloader is not None: logger.info('valid dataloader has {} iters'.format( len(valid_dataloader))) + + use_amp = True if "AMP" in config else False + if use_amp: + AMP_RELATED_FLAGS_SETTING = { + 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, + 'FLAGS_max_inplace_grad_add': 8, + } + paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING) + scale_loss = config["AMP"].get("scale_loss", 1.0) + use_dynamic_loss_scaling = config["AMP"].get("use_dynamic_loss_scaling", + False) + scaler = paddle.amp.GradScaler( + init_loss_scaling=scale_loss, + use_dynamic_loss_scaling=use_dynamic_loss_scaling) + else: + scaler = None + # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class,