# det_r18_vd_ct.yml
# Committed by huangjun12.
# Global run settings: device, schedule length, logging cadence, and the
# paths used for saving models, loading pretrained weights and inference.
Global:
  use_gpu: true
  epoch_num: 600
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/det_ct/
  save_epoch_step: 10
  # List form is [start_iteration, interval]: evaluation is run every 1000
  # iterations, starting from iteration 0.
  eval_batch_step: [0,1000]
  cal_metric_during_train: False
  pretrained_model: ./pretrain_models/ResNet18_vd_pretrained.pdparams
  # The next two keys are intentionally empty (parsed as null).
  # NOTE(review): presumably "no checkpoint to resume from" and "no inference
  # export directory" — confirm against the training/export scripts.
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  # Image used for single-image inference; predictions go to save_res_path.
  infer_img: doc/imgs_en/img623.jpg
  save_res_path: ./output/det_ct/predicts_ct.txt

# Model definition: a detection model (model_type: det) using the CT
# algorithm, assembled from a backbone, a neck and a head.
Architecture:
  model_type: det
  algorithm: CT
  # Empty (parsed as null): no Transform stage is configured.
  Transform:
  Backbone:
    name: ResNet_vd
    layers: 18
  Neck:
    name: CTFPN
  Head:
    name: CT_Head
    # NOTE(review): in_channels presumably has to match the Neck's output
    # channel count — confirm against the CTFPN implementation.
    in_channels: 512
    hidden_dim: 128
    num_classes: 3

# Training loss, selected by name.
Loss:
  name: CTLoss

# Optimizer and learning-rate schedule.
Optimizer:
  name: Adam
  # Schedule registered under the name "Linear"; the original note tags it as
  # PolynomialDecay, which matches the end_lr/power parameters below.
  lr:
    name: Linear
    learning_rate: 0.001
    end_lr: 0.0
    # NOTE(review): same value as Global.epoch_num (600) — presumably the two
    # should be kept in sync; confirm whether the trainer overrides this.
    epochs: 600
    step_each_epoch: 1254
    power: 0.9

# Post-processing of the model output, selected by name.
PostProcess:
  name: CTPostProcess
  # Output box representation requested from the post-processor.
  box_type: poly

# Evaluation metric, selected by name.
Metric:
  name: CTMetric
  # NOTE(review): presumably the score used to pick the best model — confirm.
  main_indicator: f_score

# Training data pipeline (dataset + transforms) and data-loader settings.
Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/total_text/train
    label_file_list:
      - ./train_data/total_text/train/train.txt
    ratio_list: [1.0]
    # Each entry is a single-key mapping {OperatorName: arguments}; an empty
    # value means no arguments are given here. Presumably applied in the
    # order listed — confirm against the dataset implementation.
    transforms:
      - DecodeImage:
          img_mode: RGB
          channel_first: false
      - CTLabelEncode:  # Class handling label
      - RandomScale:
      - MakeShrink:
      - GroupRandomHorizontalFlip:
      - GroupRandomRotate:
      - GroupRandomCropPadding:
      - MakeCentripetalShift:
      - ColorJitter:
          brightness: 0.125
          saturation: 0.5
      - ToCHWImage:
      - NormalizeImage:
      - KeepKeys:
          # the order of the dataloader list
          keep_keys:
            - 'image'
            - 'gt_kernel'
            - 'training_mask'
            - 'gt_instance'
            - 'gt_kernel_instance'
            - 'training_mask_distance'
            - 'gt_distance'
  loader:
    shuffle: true
    drop_last: true
    batch_size_per_card: 4
    num_workers: 8

# Evaluation data pipeline (dataset + transforms) and data-loader settings.
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/total_text/test
    label_file_list:
      - ./train_data/total_text/test/test.txt
    ratio_list: [1.0]
    # Each entry is a single-key mapping {OperatorName: arguments}; an empty
    # value means no arguments are given here.
    transforms:
      - DecodeImage:
          img_mode: RGB
          channel_first: false
      - CTLabelEncode:  # Class handling label
      - ScaleAlignedShort:
      # Normalization is declared with order 'hwc' and listed before
      # ToCHWImage.
      - NormalizeImage:
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # the order of the dataloader list
          keep_keys:
            - 'image'
            - 'shape'
            - 'polys'
            - 'texts'
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 1
    num_workers: 2