ch_PP-OCRv3_rec_distillation.yml 4.6 KB
Newer Older
A
andyjpaddle 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
Global:
  debug: false
  use_gpu: true
  epoch_num: 800
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_ppocr_v3_distillation
  save_epoch_step: 3
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  max_text_length: &max_text_length 25
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt
22
  d2s_train_image_shape: [3, 48, -1]
A
andyjpaddle 已提交
23 24 25 26 27 28 29 30


Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
31
    decay_epochs : [700]
A
andyjpaddle 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48
    values : [0.0005, 0.00005]
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 3.0e-05


Architecture:
  model_type: &model_type "rec"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
49
      algorithm: SVTR_LCNet
A
andyjpaddle 已提交
50 51 52 53 54 55
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
56
        last_pool_kernel_size: [2, 2]
A
andyjpaddle 已提交
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
77
      algorithm: SVTR_LCNet
A
andyjpaddle 已提交
78 79 80 81 82 83
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
84
        last_pool_kernel_size: [2, 2]
A
andyjpaddle 已提交
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDMLLoss:
      weight: 1.0
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: ctc
      name: dml_ctc
  - DistillationDMLLoss:
      weight: 0.5
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: sar
      name: dml_sar
  - DistillationDistanceLoss:
      weight: 1.0
      mode: "l2"
      model_name_pairs:
      - ["Student", "Teacher"]
      key: backbone_out
  - DistillationCTCLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True
  - DistillationSARLoss:
littletomatodonkey's avatar
littletomatodonkey 已提交
135
      weight: 1.0
A
andyjpaddle 已提交
136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True

PostProcess:
  name: DistillationCTCLabelDecode
  model_name: ["Student", "Teacher"]
  key: head_out
  multi_head: True

Metric:
  name: DistillationMetric
  base_metric_name: RecMetric
  main_indicator: acc
  key: "Student"
littletomatodonkey's avatar
littletomatodonkey 已提交
151
  ignore_space: False
A
andyjpaddle 已提交
152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/
    ext_op_transform_idx: 1
    label_file_list:
    - ./train_data/train_list.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - RecConAug:
        prob: 0.5
        ext_data_num: 2
        image_shape: [48, 320, 3]
A
andyjpaddle 已提交
168
        max_text_length: *max_text_length
A
andyjpaddle 已提交
169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
    - RecAug:
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: true
    batch_size_per_card: 128
    drop_last: true
    num_workers: 4
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data
    label_file_list:
    - ./train_data/val_list.txt
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 4