# ch_PP-OCR_det_cml.yml
# Run-wide settings: device, schedule length, logging cadence, and the
# paths used for saving, loading and inference output.
Global:
  use_gpu: true
  epoch_num: 1200
  log_smooth_window: 20
  print_batch_step: 2
  save_model_dir: ./output/ch_db_mv3/
  save_epoch_step: 1200
  # evaluation is run every 2000 iterations after the 3000th iteration
  eval_batch_step: [3000, 2000]
  cal_metric_during_train: false
  pretrained_model: ./pretrain_models/ch_PP-OCRv2_det_distill_train/best_accuracy
  # Left unset (null) on purpose.
  checkpoints: null
  save_inference_dir: null
  use_visualdl: false
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_db/predicts_db.txt

# Distillation setup with three DB detection sub-models: one frozen
# Teacher and two trainable students with identical settings.
Architecture:
  name: DistillationModel
  algorithm: Distillation
  Models:
    # Teacher: ResNet-18 backbone; freeze_params is true.
    Teacher:
      freeze_params: true
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform: null
      Backbone:
        name: ResNet
        layers: 18
      Neck:
        name: DBFPN
        out_channels: 256
      Head:
        name: DBHead
        k: 50
    # Student: MobileNetV3-large backbone at scale 0.5; freeze_params is false.
    Student:
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50
    # Student2: same backbone/neck/head settings as Student.
    Student2:
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform: null
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50

# Combined loss made of three weighted terms: student-vs-teacher
# (DistillationDilaDBLoss), student-vs-student (DistillationDMLLoss) and
# a per-student DB loss (DistillationDBLoss).
Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDilaDBLoss:
      weight: 1.0
      model_name_pairs:
      - ["Student", "Teacher"]
      - ["Student2", "Teacher"]
      key: maps
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3
  - DistillationDMLLoss:
      # model_name_pairs used to be declared twice in this mapping, which
      # is invalid YAML (duplicate key). Only the later declaration, the
      # one last-wins loaders actually kept, is retained.
      model_name_pairs: ["Student", "Student2"]
      # NOTE(review): the spelling "thrink_maps" is kept verbatim; it is
      # presumably matched by name in the loss code - verify before renaming.
      maps_name: "thrink_maps"
      weight: 1.0
      # act: None
      key: maps
  - DistillationDBLoss:
      weight: 1.0
      model_name_list: ["Student", "Student2"]
      # key: maps
      # name: DBLoss
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3


# Adam optimizer driven by a Cosine learning-rate schedule with a
# 2-epoch warmup; the L2 regularizer is configured with factor 0.
Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 2
  regularizer: {name: L2, factor: 0}

# Post-processing applied to the outputs of all three sub-models named
# in model_name.
PostProcess:
  name: DistillationDBPostProcess
  model_name: ["Student", "Student2", "Teacher"]
  # key: maps
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

# Evaluation metric: DetMetric wrapped by DistillationMetric, with hmean
# as the main indicator and "Student" as the selected key.
Metric:
  name: DistillationMetric
  base_metric_name: DetMetric
  main_indicator: hmean
  key: Student

# Training data source, transform list and loader settings.
# The transforms are a sequence, so the order below is significant.
Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
      - DecodeImage:  # load image
          img_mode: BGR
          channel_first: false
      - DetLabelEncode: null  # Class handling label
      - CopyPaste: null
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [960, 960]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          # NOTE(review): this plain scalar loads as the string "1./255.",
          # not a number - presumably evaluated by the consumer; keep as is.
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage: null
      - KeepKeys:
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask']  # the order of the dataloader list
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 8
    num_workers: 4

# Evaluation data source, transform list and loader settings.
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
      - DecodeImage:  # load image
          img_mode: BGR
          channel_first: false
      - DetLabelEncode: null  # Class handling label
      # DetResizeForTest takes no arguments here; the fixed image_shape
      # below is left commented out.
      - DetResizeForTest: null
      # image_shape: [736, 1280]
      - NormalizeImage:
          # NOTE(review): loads as the string "1./255.", not a float -
          # presumably evaluated by the consumer; do not rewrite numerically.
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: hwc
      - ToCHWImage: null
      - KeepKeys:
          keep_keys: [image, shape, polys, ignore_tags]
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 1  # must be 1
    num_workers: 2