# ch_PP-OCRv2_det_cml.yml
# PP-OCRv2 text-detection distillation config: one frozen teacher and two
# trainable students that also learn from each other.
Global:
  use_gpu: true
  epoch_num: 1200
  log_smooth_window: 20
  print_batch_step: 2
  save_model_dir: ./output/ch_db_mv3/
  save_epoch_step: 1200
  # [start_iter, interval]: evaluation runs every 2000 iterations,
  # starting after the 3000th iteration
  eval_batch_step: [3000, 2000]
  cal_metric_during_train: false
  pretrained_model: ./pretrain_models/ch_PP-OCRv2_det_distill_train/best_accuracy
  checkpoints: null
  save_inference_dir: null
  use_visualdl: false
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_db/predicts_db.txt

Architecture:
  name: DistillationModel
  algorithm: Distillation
  model_type: det
  Models:
    # Teacher: DB detector with a ResNet-18 backbone; its parameters are frozen.
    Teacher:
      freeze_params: true
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform: null
      Backbone:
        name: ResNet
        layers: 18
      Neck:
        name: DBFPN
        out_channels: 256
      Head:
        name: DBHead
        k: 50
    # Student: trainable DB detector with a MobileNetV3 (large, scale 0.5,
    # SE disabled) backbone.
    Student:
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50
    # Student2: same network as Student, trained alongside it.
    Student2:
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform: null
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50

Loss:
  name: CombinedLoss
  loss_config_list:
    # Each student is distilled from the frozen teacher.
    - DistillationDilaDBLoss:
        weight: 1.0
        model_name_pairs:
          - ["Student", "Teacher"]
          - ["Student2", "Teacher"]
        key: maps
        balance_loss: true
        main_loss_type: DiceLoss
        alpha: 5
        beta: 10
        ohem_ratio: 3
    # Mutual learning between the two students.
    - DistillationDMLLoss:
        # list of [model_a, model_b] pairs, same form as DistillationDilaDBLoss
        # (declared once; a duplicate key here is invalid YAML)
        model_name_pairs:
          - ["Student", "Student2"]
        # NOTE(review): spelling kept verbatim; presumably it must match the
        # map name DistillationDMLLoss looks up. Verify before "correcting" it.
        maps_name: "thrink_maps"
        weight: 1.0
        # act: None
        key: maps
    # Supervised DB loss on each student's own predictions.
    - DistillationDBLoss:
        weight: 1.0
        model_name_list: ["Student", "Student2"]
        # key: maps
        # name: DBLoss
        balance_loss: true
        main_loss_type: DiceLoss
        alpha: 5
        beta: 10
        ohem_ratio: 3


Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  # cosine learning-rate schedule with a 2-epoch warmup
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 2
  # L2 regularizer with a zero coefficient, so the penalty term contributes nothing
  regularizer:
    name: L2
    factor: 0

PostProcess:
  name: DistillationDBPostProcess
  # sub-models whose outputs are post-processed into boxes
  model_name: ["Student", "Student2", "Teacher"]
  # key: maps
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DistillationMetric
  base_metric_name: DetMetric
  main_indicator: hmean
  # presumably selects which sub-model's predictions are scored
  key: Student

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
      - DecodeImage:  # load image
          img_mode: BGR
          channel_first: false
      - DetLabelEncode:  # label encoding for detection
      - CopyPaste:
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [960, 960]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          # a string expression, not a number; NOTE(review): presumably
          # evaluated by NormalizeImage
          scale: '1./255.'
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # the order of the dataloader list
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask']
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
      - DecodeImage:  # load image
          img_mode: BGR
          channel_first: false
      - DetLabelEncode:  # label encoding for detection
      - DetResizeForTest:
          # image_shape: [736, 1280]
      - NormalizeImage:
          # a string expression, not a number
          scale: '1./255.'
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: false
    drop_last: false
    # must be 1 (NOTE(review): presumably because resized eval images vary in size)
    batch_size_per_card: 1
    num_workers: 2