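# ch_PP-OCRv2_det_cml.yml
# PP-OCRv2 text-detection distillation config: a frozen ResNet18_vd DB teacher
# supervises two MobileNetV3 (large, scale 0.5) DB students, which are also
# trained against each other.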
Global:
  use_gpu: true
  epoch_num: 1200
  log_smooth_window: 20
  print_batch_step: 2
  save_model_dir: ./output/ch_db_mv3/
  save_epoch_step: 1200
  # [start_iter, interval]: evaluation is run every 2000 iterations after the
  # 3000th iteration
  eval_batch_step: [3000, 2000]
  cal_metric_during_train: false
  # NOTE(review): left empty (null); the Teacher's weights are given under
  # Architecture.Models.Teacher.pretrained instead
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_db/predicts_db.txt
  # mixed-precision settings; presumably ignored while use_amp is false
  use_amp: false
  amp_level: O2
  amp_custom_black_list: ['exp']

Architecture:
  name: DistillationModel
  algorithm: Distillation
  model_type: det
  Models:
    # Teacher: ResNet18_vd DB detector loaded from a trained checkpoint and
    # frozen (freeze_params: true).
    Teacher:
      pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
      freeze_params: true
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: ResNet_vd
        layers: 18
      Neck:
        name: DBFPN
        out_channels: 256
      Head:
        name: DBHead
        k: 50
    # Student and Student2 share the same lightweight MobileNetV3 DB
    # architecture. No pretrained weights are specified for either, and both
    # are trainable. See Loss for how they are supervised.
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50
    Student2:
      pretrained:
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50

Loss:
  name: CombinedLoss
  loss_config_list:
    # Both students are distilled from the teacher's "maps" output.
    - DistillationDilaDBLoss:
        weight: 1.0
        model_name_pairs:
          - ["Student", "Teacher"]
          - ["Student2", "Teacher"]
        key: maps
        balance_loss: true
        main_loss_type: DiceLoss
        alpha: 5
        beta: 10
        ohem_ratio: 3
    # The two students are also trained against each other.
    - DistillationDMLLoss:
        weight: 1.0
        # Declared exactly once: duplicate keys are invalid YAML and loaders
        # silently keep only the last occurrence.
        model_name_pairs: ["Student", "Student2"]
        # NOTE(review): "thrink_maps" (sic) is kept verbatim; presumably it must
        # match a name the loss implementation checks, so confirm before renaming.
        maps_name: "thrink_maps"
        key: maps
    # Per-student DB loss; presumably computed against the ground-truth maps
    # produced by the Train transforms (confirm).
    - DistillationDBLoss:
        weight: 1.0
        model_name_list: ["Student", "Student2"]
        balance_loss: true
        main_loss_type: DiceLoss
        alpha: 5
        beta: 10
        ohem_ratio: 3


Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  # learning-rate schedule: Cosine, base rate 0.001, warmup_epoch 2
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 2
  # NOTE(review): factor 0 presumably makes the L2 penalty a no-op; confirm
  # this is intended
  regularizer:
    name: L2
    factor: 0

PostProcess:
  name: DistillationDBPostProcess
  # sub-models whose outputs are post-processed (presumably each one
  # separately; confirm)
  model_name: ["Student", "Student2", "Teacher"]
  # key: maps
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DistillationMetric
  # score the "Student" sub-model with DetMetric; hmean is the main indicator
  key: Student
  base_metric_name: DetMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: false
      - DetLabelEncode: # encode the text-detection annotation
      - CopyPaste:
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [960, 960]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          # parsed as the string "1./255.", not a YAML float
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
      # load image
      - DecodeImage:
          img_mode: BGR
          channel_first: false
      # encode the text-detection annotation
      - DetLabelEncode:
      # no parameters given (null)
      - DetResizeForTest:
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: hwc
      - ToCHWImage:
      # fields handed to the evaluation loop, in this order
      - KeepKeys:
          keep_keys:
            - image
            - shape
            - polys
            - ignore_tags
  loader:
    shuffle: false
    drop_last: false
    # must be 1
    batch_size_per_card: 1
    num_workers: 2