# ch_PP-OCRv2_det_cml.yml
# PP-OCRv2 text-detection training config using CML distillation:
# one frozen Teacher and two trainable Students (see Architecture).
Global:
  use_gpu: true
  epoch_num: 1200
  log_smooth_window: 20
  print_batch_step: 2
  save_model_dir: ./output/ch_db_mv3/
  save_epoch_step: 1200
  # evaluation is run every 2000 iterations after the 3000th iteration
  eval_batch_step: [3000, 2000]
  cal_metric_during_train: false
  pretrained_model: ./pretrain_models/ch_PP-OCRv2_det_distill_train/best_accuracy
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_db/predicts_db.txt

Architecture:
  name: DistillationModel
  algorithm: Distillation
  model_type: det
  Models:
    # Teacher: ResNet-18 DB detector initialised from the server det weights
    # and kept frozen during training (freeze_params: true).
    Teacher:
      pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
      freeze_params: true
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: ResNet
        layers: 18
      Neck:
        name: DBFPN
        out_channels: 256
      Head:
        name: DBHead
        k: 50
    # Student / Student2: identical trainable MobileNetV3-large (scale 0.5,
    # disable_se: true) DB detectors. No per-model `pretrained` path is set;
    # NOTE(review): whether Global.pretrained_model initialises them depends
    # on the loader — confirm.
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50
    Student2:
      pretrained:
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: true
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50

Loss:
  name: CombinedLoss
  loss_config_list:
    # Each Student distilled from the frozen Teacher.
    - DistillationDilaDBLoss:
        weight: 1.0
        model_name_pairs:
          - ["Student", "Teacher"]
          - ["Student2", "Teacher"]
        key: maps
        balance_loss: true
        main_loss_type: DiceLoss
        alpha: 5
        beta: 10
        ohem_ratio: 3
    # Mutual learning between the two Students.
    - DistillationDMLLoss:
        # Declare model_name_pairs only once: a duplicate key is invalid YAML
        # and most loaders silently keep just the last occurrence.
        model_name_pairs:
          - ["Student", "Student2"]
        # "thrink_maps" (sic) is kept verbatim — presumably matched as a
        # literal name by the loss code; verify before changing the spelling.
        maps_name: "thrink_maps"
        weight: 1.0
        # act: None
        key: maps
    # DB loss for each Student (presumably against the ground-truth maps).
    - DistillationDBLoss:
        weight: 1.0
        model_name_list: ["Student", "Student2"]
        # key: maps
        # name: DBLoss
        balance_loss: true
        main_loss_type: DiceLoss
        alpha: 5
        beta: 10
        ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  # Cosine learning-rate schedule, base learning_rate 0.001, 2 warm-up epochs.
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 2
  # L2 regularization with factor 0, i.e. the penalty contributes nothing.
  regularizer: {name: L2, factor: 0}

PostProcess:
  name: DistillationDBPostProcess
  # Sub-models whose outputs are post-processed.
  model_name: ["Student", "Student2", "Teacher"]
  # key: maps
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  # DistillationMetric wrapping DetMetric; `key` presumably selects which
  # sub-model's predictions are scored, with hmean as the headline indicator.
  name: DistillationMetric
  key: Student
  base_metric_name: DetMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: false
      - DetLabelEncode: # Class handling label
      - CopyPaste:
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [960, 960]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
  loader:
    shuffle: true
    drop_last: false
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: false
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
          # image_shape: [736, 1280]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 1 # must be 1
    num_workers: 2