diff --git a/configs/rec/svtr/rec_svtr_base_8local_10global_stn_ch.yml b/configs/rec/svtr/rec_svtr_base_8local_10global_stn_ch.yml new file mode 100644 index 0000000000000000000000000000000000000000..8534e78874a8412412057485cdfa71d20a9b82f2 --- /dev/null +++ b/configs/rec/svtr/rec_svtr_base_8local_10global_stn_ch.yml @@ -0,0 +1,113 @@ +Global: + use_gpu: True + epoch_num: 100 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/rec_svtr_base_stn_ch/ + save_epoch_step: 10 + # evaluation is run every 2000 iterations after the 0th iteration + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + max_text_length: 40 + infer_mode: False + use_space_char: True + save_res_path: ./output/rec/predicts_svtr_base_ch.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.99 + epsilon: 0.00000008 + weight_decay: 0.05 + no_weight_decay_name: norm pos_embed + one_dim_param_no_weight_decay: true + lr: + name: Cosine + learning_rate: 0.0003 + warmup_epoch: 5 + +Architecture: + model_type: rec + algorithm: SVTR + Transform: + name: STN_ON + tps_inputsize: [32, 64] + tps_outputsize: [32, 320] + num_control_points: 20 + tps_margins: [0.05,0.05] + stn_activation: none + Backbone: + name: SVTRNet + img_size: [32, 320] # input size 可以尝试[64,200] + out_char_num: 40 # output char patch + out_channels: 256 # char patch dim + patch_merging: 'Conv' # 是否使用patch-merging 可选Conv Pool None + embed_dim: [128, 256, 384] # 三个阶段的sub-patch dim + depth: [3, 6, 9] # 当使用patch-merging时,控制patch-merging所在的层数,分成三阶段,每个阶段的层数 + num_heads: [4, 8, 12] # 三个阶段中的sub-patch heads + mixer: ['Local','Local','Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global'] # Local atten, Global atten, Conv + local_mixer: [[7, 11], [7, 11], [7, 11]] # local mixer的范围,7表示高度的范围,11表示宽度的范围 + prenorm: False + Neck: + name: SequenceEncoder + encoder_type: reshape + Head: + name: CTCHead + +Loss: + name: CTCLoss + +PostProcess: + name: CTCLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/scene_ch/ch_scene + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 128 + drop_last: True + num_workers: 2 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/scene_ch/scene_test + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 2 diff --git a/configs/rec/svtr/rec_svtr_base_8local_10global_stn_en.yml b/configs/rec/svtr/rec_svtr_base_8local_10global_stn_en.yml new file mode 100644 index 0000000000000000000000000000000000000000..2b7546c4d208235ffc07aeaea4b67918f91b67aa --- /dev/null +++ b/configs/rec/svtr/rec_svtr_base_8local_10global_stn_en.yml @@ -0,0 +1,117 @@ +Global: + use_gpu: True + epoch_num: 20 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/rec_svtr_base_stn_en/ + save_epoch_step: 1 + # evaluation is run every 2000 iterations after the 0th iteration + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words_en/word_10.png + # for data or label process + character_dict_path: + character_type: en + max_text_length: 25 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/predicts_svtr_base.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.99 + epsilon: 0.00000008 + weight_decay: 0.05 + no_weight_decay_name: norm pos_embed + one_dim_param_no_weight_decay: true + lr: + name: Cosine + learning_rate: 0.00025 + warmup_epoch: 2 + +Architecture: + model_type: rec + algorithm: SVTR + Transform: + name: STN_ON + tps_inputsize: [32, 64] + tps_outputsize: [48, 160] + num_control_points: 20 + tps_margins: [0.05,0.05] + stn_activation: none + Backbone: + name: SVTRNet + img_size: [48, 160] # input size 可以尝试[64,200] + out_char_num: 40 # output char patch + out_channels: 256 # char patch dim + patch_merging: 'Conv' # 是否使用patch-merging 可选Conv Pool None + embed_dim: [128, 256, 384] # 三个阶段的sub-patch dim + depth: [3, 6, 9] # 当使用patch-merging时,控制patch-merging所在的层数,分成三阶段,每个阶段的层数 + num_heads: [4, 8, 12] # 三个阶段中的sub-patch heads + mixer: ['Local','Local','Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global'] # Local atten, Global atten, Conv + local_mixer: [[7, 11], [7, 11], [7, 11]] # local mixer的范围,7表示高度的范围,11表示宽度的范围 + last_stage: True + prenorm: False + Neck: + name: SequenceEncoder + encoder_type: reshape + Head: + name: CTCHead + +Loss: + name: CTCLoss + +PostProcess: + name: SVTRLabelDecode # SVTRLabelDecode is used for eval, please change to CTCLabelDecode when training + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/training + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - SVTRRecResizeImg: # SVTRRecResizeImg is used for eval, please change to RecResizeImg when training + character_dict_path: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 256 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/evaluation/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - SVTRRecResizeImg: # SVTRRecResizeImg is used for eval, please change to RecResizeImg when training + character_dict_path: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 128 + num_workers: 2 diff --git a/configs/rec/svtr/rec_svtr_large_10local_11global_stn_ch.yml b/configs/rec/svtr/rec_svtr_large_10local_11global_stn_ch.yml new file mode 100644 index 0000000000000000000000000000000000000000..68f0608f013eddff22615684724cfb9c4e150b49 --- /dev/null +++ b/configs/rec/svtr/rec_svtr_large_10local_11global_stn_ch.yml @@ -0,0 +1,113 @@ +Global: + use_gpu: True + epoch_num: 100 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/rec_svtr_large_ch/ + save_epoch_step: 10 + # evaluation is run every 2000 iterations after the 0th iteration + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + max_text_length: 40 + infer_mode: False + use_space_char: True + save_res_path: ./output/rec/predicts_svtr_large_ch.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.99 + epsilon: 0.00000008 + weight_decay: 0.05 + no_weight_decay_name: norm pos_embed + one_dim_param_no_weight_decay: true + lr: + name: Cosine + learning_rate: 0.0003 + warmup_epoch: 5 + +Architecture: + model_type: rec + algorithm: SVTR + Transform: + name: STN_ON + tps_inputsize: [32, 64] + tps_outputsize: [32, 320] + num_control_points: 20 + tps_margins: [0.05,0.05] + stn_activation: none + Backbone: + name: SVTRNet + img_size: [32, 320] # input size 可以尝试[64,200] + out_char_num: 40 # output char patch + out_channels: 384 # char patch dim + patch_merging: 'Conv' # 是否使用patch-merging 可选Conv Pool None + embed_dim: [192, 256, 512] # 三个阶段的sub-patch dim + depth: [3, 9, 9] # 当使用patch-merging时,控制patch-merging所在的层数,分成三阶段,每个阶段的层数 + num_heads: [6, 8, 16] # 三个阶段中的sub-patch heads + mixer: ['Local','Local','Local','Local','Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global'] # Local atten, Global atten, Conv + local_mixer: [[7, 11], [7, 11], [7, 11]] # local mixer的范围,7表示高度的范围,11表示宽度的范围 + prenorm: False + Neck: + name: SequenceEncoder + encoder_type: reshape + Head: + name: CTCHead + +Loss: + name: CTCLoss + +PostProcess: + name: CTCLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/scene_ch/ch_scene + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 128 + drop_last: True + num_workers: 2 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/scene_ch/scene_test + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 2 diff --git a/configs/rec/svtr/rec_svtr_large_10local_11global_stn_en.yml b/configs/rec/svtr/rec_svtr_large_10local_11global_stn_en.yml new file mode 100644 index 0000000000000000000000000000000000000000..b995bb81a2a82c18458c8f9b44ddb4b71f59a692 --- /dev/null +++ b/configs/rec/svtr/rec_svtr_large_10local_11global_stn_en.yml @@ -0,0 +1,117 @@ +Global: + use_gpu: True + epoch_num: 20 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/rec_svtr_large_en/ + save_epoch_step: 1 + # evaluation is run every 2000 iterations after the 0th iteration + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words_en/word_10.png + # for data or label process + character_dict_path: + character_type: en + max_text_length: 25 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/predicts_svtr_large.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.99 + epsilon: 0.00000008 + weight_decay: 0.05 + no_weight_decay_name: norm pos_embed + one_dim_param_no_weight_decay: true + lr: + name: Cosine + learning_rate: 0.000125 + warmup_epoch: 2 + +Architecture: + model_type: rec + algorithm: SVTR + Transform: + name: STN_ON + tps_inputsize: [32, 64] + tps_outputsize: [48, 160] + num_control_points: 20 + tps_margins: [0.05,0.05] + stn_activation: none + Backbone: + name: SVTRNet + img_size: [48, 160] # input size 可以尝试[64,200] + out_char_num: 40 # output char patch + out_channels: 384 # char patch dim + patch_merging: 'Conv' # 是否使用patch-merging 可选Conv Pool None + embed_dim: [192, 256, 512] # 三个阶段的sub-patch dim + depth: [3, 9, 9] # 当使用patch-merging时,控制patch-merging所在的层数,分成三阶段,每个阶段的层数 + num_heads: [6, 8, 16] # 三个阶段中的sub-patch heads + mixer: ['Local','Local','Local','Local','Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global','Global'] # Local atten, Global atten, Conv + local_mixer: [[7, 11], [7, 11], [7, 11]] # local mixer的范围,7表示高度的范围,11表示宽度的范围 + prenorm: false + Neck: + name: SequenceEncoder + encoder_type: reshape + Head: + name: CTCHead + +Loss: + name: CTCLoss + +PostProcess: + name: SVTRLabelDecode # SVTRLabelDecode is used for eval, please change to CTCLabelDecode when training + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/training + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - RecAug: + - CTCLabelEncode: # Class handling label + - SVTRRecResizeImg: # SVTRRecResizeImg is used for eval, please change to RecResizeImg when training + character_dict_path: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 128 + drop_last: True + num_workers: 2 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/evaluation/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - SVTRRecResizeImg: # SVTRRecResizeImg is used for eval, please change to RecResizeImg when training + character_dict_path: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 128 + num_workers: 2 diff --git a/configs/rec/svtr/rec_svtr_small_8local_7global_stn_ch.yml b/configs/rec/svtr/rec_svtr_small_8local_7global_stn_ch.yml new file mode 100644 index 0000000000000000000000000000000000000000..38ddb0e4fc8abd5405d1d7e370fb982797656fa3 --- /dev/null +++ b/configs/rec/svtr/rec_svtr_small_8local_7global_stn_ch.yml @@ -0,0 +1,114 @@ +Global: + use_gpu: True + epoch_num: 100 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/rec_svtr_small_ch/ + save_epoch_step: 10 + # evaluation is run every 2000 iterations after the 0th iteration + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + max_text_length: 40 + infer_mode: False + use_space_char: True + save_res_path: ./output/rec/predicts_svtr_small_ch.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.99 + epsilon: 0.00000008 + weight_decay: 0.05 + no_weight_decay_name: norm pos_embed + one_dim_param_no_weight_decay: true + lr: + name: Cosine + learning_rate: 0.0003 + warmup_epoch: 5 + +Architecture: + model_type: rec + algorithm: SVTR + Transform: + name: STN_ON + tps_inputsize: [32, 64] + tps_outputsize: [32, 320] + num_control_points: 20 + tps_margins: [0.05,0.05] + stn_activation: none + Backbone: + name: SVTRNet + img_size: [32, 320] # input size 可以尝试[64,200] + out_char_num: 40 # output char patch + out_channels: 192 # char patch dim + patch_merging: 'Conv' # 是否使用patch-merging 可选Conv Pool None + embed_dim: [96, 192, 256] # 三个阶段的sub-patch dim + depth: [3, 6, 6] # 当使用patch-merging时,控制patch-merging所在的层数,分成三阶段,每个阶段的层数 + num_heads: [3, 6, 8] # 三个阶段中的sub-patch heads + mixer: ['Local','Local','Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global','Global'] # Local atten, Global atten, Conv + local_mixer: [[7, 11], [7, 11], [7, 11]] # local mixer的范围,7表示高度的范围,11表示宽度的范围 + last_stage: True + prenorm: False + Neck: + name: SequenceEncoder + encoder_type: reshape + Head: + name: CTCHead + +Loss: + name: CTCLoss + +PostProcess: + name: CTCLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/scene_ch/ch_scene + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 128 + drop_last: True + num_workers: 2 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/scene_ch/scene_test + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 2 diff --git a/configs/rec/svtr/rec_svtr_small_8local_7global_stn_en.yml b/configs/rec/svtr/rec_svtr_small_8local_7global_stn_en.yml new file mode 100644 index 0000000000000000000000000000000000000000..69fea2e1f6de4821d3468d325340855fdd2b9a86 --- /dev/null +++ b/configs/rec/svtr/rec_svtr_small_8local_7global_stn_en.yml @@ -0,0 +1,117 @@ +Global: + use_gpu: True + epoch_num: 20 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/rec_svtr_small_stn_en/ + save_epoch_step: 1 + # evaluation is run every 2000 iterations after the 0th iteration + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words_en/word_10.png + # for data or label process + character_dict_path: + character_type: en + max_text_length: 25 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/predicts_svtr_small.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.99 + epsilon: 0.00000008 + weight_decay: 0.05 + no_weight_decay_name: norm pos_embed + one_dim_param_no_weight_decay: true + lr: + name: Cosine + learning_rate: 0.0005 + warmup_epoch: 2 + +Architecture: + model_type: rec + algorithm: SVTR + Transform: + name: STN_ON + tps_inputsize: [32, 64] + tps_outputsize: [32, 100] + num_control_points: 20 + tps_margins: [0.05,0.05] + stn_activation: none + Backbone: + name: SVTRNet + img_size: [32, 100] # input size 可以尝试[64,200] + out_char_num: 25 # output char patch + out_channels: 192 # char patch dim + patch_merging: 'Conv' # 是否使用patch-merging 可选Conv Pool None + embed_dim: [96, 192, 256] # 三个阶段的sub-patch dim + depth: [3, 6, 6] # 当使用patch-merging时,控制patch-merging所在的层数,分成三阶段,每个阶段的层数 + num_heads: [3, 6, 8] # 三个阶段中的sub-patch heads + mixer: ['Local','Local','Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global','Global'] # Local atten, Global atten, Conv + local_mixer: [[7, 11], [7, 11], [7, 11]] # local mixer的范围,7表示高度的范围,11表示宽度的范围 + last_stage: True + prenorm: False + Neck: + name: SequenceEncoder + encoder_type: reshape + Head: + name: CTCHead + +Loss: + name: CTCLoss + +PostProcess: + name: SVTRLabelDecode # SVTRLabelDecode is used for eval, please change to CTCLabelDecode when training + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/training + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - SVTRRecResizeImg: # SVTRRecResizeImg is used for eval, please change to RecResizeImg when training + character_dict_path: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 512 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/evaluation + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - SVTRRecResizeImg: # SVTRRecResizeImg is used for eval, please change to RecResizeImg when training + character_dict_path: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 2 diff --git a/configs/rec/svtr/rec_svtr_tiny_6local_6global_stn_ch.yml b/configs/rec/svtr/rec_svtr_tiny_6local_6global_stn_ch.yml new file mode 100644 index 0000000000000000000000000000000000000000..e0d77f632cf125d9b8a10647035ff3c48d2ebf39 --- /dev/null +++ b/configs/rec/svtr/rec_svtr_tiny_6local_6global_stn_ch.yml @@ -0,0 +1,114 @@ +Global: + use_gpu: True + epoch_num: 100 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/rec_svtr_tiny_ch/ + save_epoch_step: 10 + # evaluation is run every 2000 iterations after the 0th iteration + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: ppocr/utils/ppocr_keys_v1.txt + max_text_length: 40 + infer_mode: False + use_space_char: True + save_res_path: ./output/rec/predicts_svtr_tiny_ch.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.99 + epsilon: 0.00000008 + weight_decay: 0.05 + no_weight_decay_name: norm pos_embed + one_dim_param_no_weight_decay: true + lr: + name: Cosine + learning_rate: 0.0003 + warmup_epoch: 5 + +Architecture: + model_type: rec + algorithm: SVTR + Transform: + name: STN_ON + tps_inputsize: [32, 64] + tps_outputsize: [32, 320] + num_control_points: 20 + tps_margins: [0.05,0.05] + stn_activation: none + Backbone: + name: SVTRNet + img_size: [32, 320] # input size + out_char_num: 40 # number char patch + out_channels: 192 # char patch dim + patch_merging: 'Conv' # 是否使用patch-merging 可选Conv Pool None + embed_dim: [64, 128, 256] # 三个阶段的sub-patch dim + depth: [3, 6, 3] # 当使用patch-merging时,控制patch-merging所在的层数,分成三阶段,每个阶段的层数 + num_heads: [2, 4, 8] # 三个阶段中的sub-patch heads + mixer: ['Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global'] # Local atten, Global atten, Conv + local_mixer: [[7, 11], [7, 11], [7, 11]] # local mixer的范围,7表示高度的范围,11表示宽度的范围 + last_stage: True # 三个阶段中的sub-patch heads + prenorm: false + Neck: + name: SequenceEncoder + encoder_type: reshape + Head: + name: CTCHead + +Loss: + name: CTCLoss + +PostProcess: + name: CTCLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/scene_ch/ch_scene + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 128 + drop_last: True + num_workers: 2 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/scene_ch/scene_test + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 2 diff --git a/configs/rec/svtr/rec_svtr_tiny_6local_6global_stn_en.yml b/configs/rec/svtr/rec_svtr_tiny_6local_6global_stn_en.yml new file mode 100644 index 0000000000000000000000000000000000000000..2bd9970b08b832d161b73efb38ee510fbe2ef280 --- /dev/null +++ b/configs/rec/svtr/rec_svtr_tiny_6local_6global_stn_en.yml @@ -0,0 +1,117 @@ +Global: + use_gpu: True + epoch_num: 20 + log_smooth_window: 20 + print_batch_step: 10 + save_model_dir: ./output/rec/rec_svtr_tiny_en/ + save_epoch_step: 1 + # evaluation is run every 2000 iterations after the 0th iteration + eval_batch_step: [0, 2000] + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words_en/word_10.png + # for data or label process + character_dict_path: + character_type: en + max_text_length: 25 + infer_mode: False + use_space_char: False + save_res_path: ./output/rec/predicts_svtr_tiny.txt + + +Optimizer: + name: AdamW + beta1: 0.9 + beta2: 0.99 + epsilon: 0.00000008 + weight_decay: 0.05 + no_weight_decay_name: norm pos_embed + one_dim_param_no_weight_decay: true + lr: + name: Cosine + learning_rate: 0.0005 + warmup_epoch: 2 + +Architecture: + model_type: rec + algorithm: SVTR + Transform: + name: STN_ON + tps_inputsize: [32, 64] + tps_outputsize: [32, 100] + num_control_points: 20 + tps_margins: [0.05,0.05] + stn_activation: none + Backbone: + name: SVTRNet + img_size: [32, 100] # input size 可以尝试[64,200] + out_char_num: 25 # output char patch + out_channels: 192 # char patch dim + patch_merging: 'Conv' # 是否使用patch-merging 可选Conv Pool None + embed_dim: [64, 128, 256] # 三个阶段的sub-patch dim + depth: [3, 6, 3] # 当使用patch-merging时,控制patch-merging所在的层数,分成三阶段,每个阶段的层数 + num_heads: [2, 4, 8] # 三个阶段中的sub-patch heads + mixer: ['Local','Local','Local','Local','Local','Local','Global','Global','Global','Global','Global','Global'] # Local atten, Global atten, Conv + local_mixer: [[7, 11], [7, 11], [7, 11]] # local mixer的范围,7表示高度的范围,11表示宽度的范围 + last_stage: True # 三个阶段中的sub-patch heads + prenorm: false + Neck: + name: SequenceEncoder + encoder_type: reshape + Head: + name: CTCHead + +Loss: + name: CTCLoss + +PostProcess: + name: CTCLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/training/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + character_dict_path: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: True + batch_size_per_card: 512 + drop_last: True + num_workers: 2 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/evaluation/ + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - CTCLabelEncode: # Class handling label + - RecResizeImg: + character_dict_path: + image_shape: [3, 64, 256] + padding: False + - KeepKeys: + keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order + loader: + shuffle: False + drop_last: False + batch_size_per_card: 256 + num_workers: 2 diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index c24886aa89dbd0957a313dded862de5893fb6817..7580e607afb356a1032c4d6b2d2267bff608a80d 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -23,7 +23,7 @@ from .random_crop_data import EastRandomCropData, RandomCropImgMask from .make_pse_gt import MakePseGt from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \ - SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg + SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, SVTRRecResizeImg from .randaugment import RandAugment from .copy_paste import CopyPaste from .ColorJitter import ColorJitter diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 960a11be16a9090d80b5c5a27069246d1bcaa3e7..e3fb4d7eb3a572e4f91da91c1f219e3c30276798 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -16,6 +16,7 @@ import math import cv2 import numpy as np import random +import copy from PIL import Image from .text_image_aug import tia_perspective, tia_stretch, tia_distort @@ -206,6 +207,27 @@ class PRENResizeImg(object): return data +class SVTRRecResizeImg(object): + def __init__(self, + image_shape, + infer_mode=False, + character_dict_path='./ppocr/utils/ppocr_keys_v1.txt', + padding=True, + **kwargs): + self.image_shape = image_shape + self.infer_mode = infer_mode + self.character_dict_path = character_dict_path + self.padding = padding + + + def __call__(self, data): + img = data['image'] + norm_img = resize_norm_img_svtr(img, self.image_shape, self.padding) + data['image'] = norm_img + return data + + + def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25): imgC, imgH, imgW_min, imgW_max = image_shape h = img.shape[0] @@ -324,6 +346,60 @@ def resize_norm_img_srn(img, image_shape): return np.reshape(img_black, (c, row, col)).astype(np.float32) + +def resize_norm_img_svtr(img, image_shape, padding=True): + imgC, imgH, imgW = image_shape + h = img.shape[0] + w = img.shape[1] + if not padding: + + if h > 2.0 * w: + image = Image.fromarray(img) + image1 = image.rotate(90, expand=True) + image2 = image.rotate(-90, expand=True) + img1 = np.array(image1) + img2 = np.array(image2) + else: + img1 = copy.deepcopy(img) + img2 = copy.deepcopy(img) + + resized_image = cv2.resize( + img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image1 = cv2.resize( + img1, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image2 = cv2.resize( + img2, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_w = imgW + else: + ratio = w / float(h) + if math.ceil(imgH * ratio) > imgW: + resized_w = imgW + else: + resized_w = int(math.ceil(imgH * ratio)) + resized_image = cv2.resize(img, (resized_w, imgH)) + resized_image = resized_image.astype('float32') + resized_image1 = resized_image1.astype('float32') + resized_image2 = resized_image2.astype('float32') + if image_shape[0] == 1: + resized_image = resized_image / 255 + resized_image = resized_image[np.newaxis, :] + else: + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image1 = resized_image1.transpose((2, 0, 1)) / 255 + resized_image2 = resized_image2.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + resized_image1 -= 0.5 + resized_image1 /= 0.5 + resized_image2 -= 0.5 + resized_image2 /= 0.5 + padding_im = np.zeros((3, imgC, imgH, imgW), dtype=np.float32) + padding_im[0, :, :, 0:resized_w] = resized_image + padding_im[1, :, :, 0:resized_w] = resized_image1 + padding_im[2, :, :, 0:resized_w] = resized_image2 + return padding_im + + def srn_other_inputs(image_shape, num_heads, max_text_length): imgC, imgH, imgW = image_shape diff --git a/ppocr/modeling/backbones/rec_svtrnet.py b/ppocr/modeling/backbones/rec_svtrnet.py index bef8f3688d8efaa113a81d1acb135273746e6a93..2b2c6c6de50bb5dde0887cbe3bade636792b7f29 100644 --- a/ppocr/modeling/backbones/rec_svtrnet.py +++ b/ppocr/modeling/backbones/rec_svtrnet.py @@ -455,7 +455,7 @@ class SVTRNet(nn.Layer): qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, - act_layer=nn.Swish, + act_layer=eval(act), attn_drop=attn_drop_rate, drop_path=dpr[0:depth[0]][i], norm_layer=norm_layer, diff --git a/ppocr/modeling/transforms/stn.py b/ppocr/modeling/transforms/stn.py index 6f2bdda050f217d8253740001901fbff4065782a..1b15d5b8a7b7a1b1ab686d20acea750437463939 100644 --- a/ppocr/modeling/transforms/stn.py +++ b/ppocr/modeling/transforms/stn.py @@ -128,6 +128,8 @@ class STN_ON(nn.Layer): self.out_channels = in_channels def forward(self, image): + if len(image.shape)==5: + image = image.reshape([0, image.shape[-3], image.shape[-2], image.shape[-1]]) stn_input = paddle.nn.functional.interpolate( image, self.tps_inputsize, mode="bilinear", align_corners=True) stn_img_feat, ctrl_points = self.stn_head(stn_input) diff --git a/ppocr/modeling/transforms/tps_spatial_transformer.py b/ppocr/modeling/transforms/tps_spatial_transformer.py index 043bb56b8a526c12b2e0799bf41e128c6499c1fc..cb1cb10aaa98dffa2f720dc81afdf82d25e071ca 100644 --- a/ppocr/modeling/transforms/tps_spatial_transformer.py +++ b/ppocr/modeling/transforms/tps_spatial_transformer.py @@ -138,9 +138,9 @@ class TPSSpatialTransformer(nn.Layer): assert source_control_points.shape[2] == 2 batch_size = paddle.shape(source_control_points)[0] - self.padding_matrix = paddle.expand( + padding_matrix = paddle.expand( self.padding_matrix, shape=[batch_size, 3, 2]) - Y = paddle.concat([source_control_points, self.padding_matrix], 1) + Y = paddle.concat([source_control_points, padding_matrix], 1) mapping_matrix = paddle.matmul(self.inverse_kernel, Y) source_coordinate = paddle.matmul(self.target_coordinate_repr, mapping_matrix) diff --git a/ppocr/optimizer/__init__.py b/ppocr/optimizer/__init__.py index 4110fb47678583cff826a9bc855b3fb378a533f9..a6bd2ebb4a81427245dc10e446cd2da101d53bd4 100644 --- a/ppocr/optimizer/__init__.py +++ b/ppocr/optimizer/__init__.py @@ -30,7 +30,7 @@ def build_lr_scheduler(lr_config, epochs, step_each_epoch): return lr -def build_optimizer(config, epochs, step_each_epoch, parameters): +def build_optimizer(config, epochs, step_each_epoch, model): from . import regularizer, optimizer config = copy.deepcopy(config) # step1 build lr @@ -43,6 +43,8 @@ def build_optimizer(config, epochs, step_each_epoch, parameters): if not hasattr(regularizer, reg_name): reg_name += 'Decay' reg = getattr(regularizer, reg_name)(**reg_config)() + elif 'weight_decay' in config: + reg = config.pop('weight_decay') else: reg = None @@ -57,4 +59,4 @@ def build_optimizer(config, epochs, step_each_epoch, parameters): weight_decay=reg, grad_clip=grad_clip, **config) - return optim(parameters), lr + return optim(model), lr diff --git a/ppocr/optimizer/optimizer.py b/ppocr/optimizer/optimizer.py index b98081227e180edbf023a8b5b7a0b82bb7c631e5..c450a3a3684eb44cdc758a2b27783b5a81945c38 100644 --- a/ppocr/optimizer/optimizer.py +++ b/ppocr/optimizer/optimizer.py @@ -42,13 +42,13 @@ class Momentum(object): self.weight_decay = weight_decay self.grad_clip = grad_clip - def __call__(self, parameters): + def __call__(self, model): opt = optim.Momentum( learning_rate=self.learning_rate, momentum=self.momentum, weight_decay=self.weight_decay, grad_clip=self.grad_clip, - parameters=parameters) + parameters=model.parameters()) return opt @@ -75,7 +75,7 @@ class Adam(object): self.name = name self.lazy_mode = lazy_mode - def __call__(self, parameters): + def __call__(self, model): opt = optim.Adam( learning_rate=self.learning_rate, beta1=self.beta1, @@ -85,7 +85,7 @@ class Adam(object): grad_clip=self.grad_clip, name=self.name, lazy_mode=self.lazy_mode, - parameters=parameters) + parameters=model.parameters()) return opt @@ -117,7 +117,7 @@ class RMSProp(object): self.weight_decay = weight_decay self.grad_clip = grad_clip - def __call__(self, parameters): + def __call__(self, model): opt = optim.RMSProp( learning_rate=self.learning_rate, momentum=self.momentum, @@ -125,7 +125,7 @@ class RMSProp(object): epsilon=self.epsilon, weight_decay=self.weight_decay, grad_clip=self.grad_clip, - parameters=parameters) + parameters=model.parameters()) return opt @@ -148,7 +148,7 @@ class Adadelta(object): self.grad_clip = grad_clip self.name = name - def __call__(self, parameters): + def __call__(self, model): opt = optim.Adadelta( learning_rate=self.learning_rate, epsilon=self.epsilon, @@ -156,7 +156,7 @@ class Adadelta(object): weight_decay=self.weight_decay, grad_clip=self.grad_clip, name=self.name, - parameters=parameters) + parameters=model.parameters()) return opt @@ -165,31 +165,55 @@ class AdamW(object): learning_rate=0.001, beta1=0.9, beta2=0.999, - epsilon=1e-08, + epsilon=1e-8, weight_decay=0.01, + multi_precision=False, grad_clip=None, + no_weight_decay_name=None, + one_dim_param_no_weight_decay=False, name=None, lazy_mode=False, - **kwargs): + **args): + super().__init__() self.learning_rate = learning_rate self.beta1 = beta1 self.beta2 = beta2 self.epsilon = epsilon - self.learning_rate = learning_rate + self.grad_clip = grad_clip self.weight_decay = 0.01 if weight_decay is None else weight_decay self.grad_clip = grad_clip self.name = name self.lazy_mode = lazy_mode - - def __call__(self, parameters): + self.multi_precision = multi_precision + self.no_weight_decay_name_list = no_weight_decay_name.split( + ) if no_weight_decay_name else [] + self.one_dim_param_no_weight_decay = one_dim_param_no_weight_decay + + def __call__(self, model): + parameters = model.parameters() + + self.no_weight_decay_param_name_list = [ + p.name for n, p in model.named_parameters() if any(nd in n for nd in self.no_weight_decay_name_list) + ] + + if self.one_dim_param_no_weight_decay: + self.no_weight_decay_param_name_list += [ + p.name for n, p in model.named_parameters() if len(p.shape) == 1 + ] + opt = optim.AdamW( learning_rate=self.learning_rate, beta1=self.beta1, beta2=self.beta2, epsilon=self.epsilon, + parameters=parameters, weight_decay=self.weight_decay, + multi_precision=self.multi_precision, grad_clip=self.grad_clip, name=self.name, lazy_mode=self.lazy_mode, - parameters=parameters) + apply_decay_param_fun=self._apply_decay_param_fun) return opt + + def _apply_decay_param_fun(self, name): + return name not in self.no_weight_decay_param_name_list \ No newline at end of file diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index f50b5f1c5f8e617066bb47636c8f4d2b171b6ecb..390f6f4560f9814a3af757a4fd16c55fe93d01f9 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -27,7 +27,7 @@ from .sast_postprocess import SASTPostProcess from .fce_postprocess import FCEPostProcess from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ - SEEDLabelDecode, PRENLabelDecode + SEEDLabelDecode, PRENLabelDecode, SVTRLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess @@ -42,7 +42,7 @@ def build_post_process(config, global_config=None): 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', - 'DistillationSARLabelDecode' + 'DistillationSARLabelDecode', 'SVTRLabelDecode' ] if config['name'] == 'PSEPostProcess': diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index bf0fd890bf25949361665d212bf8e1a657054e5b..50f11f899fb4dd49da75199095772a92cc4a8d7b 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -752,3 +752,40 @@ class PRENLabelDecode(BaseRecLabelDecode): return text label = self.decode(label) return text, label + + +class SVTRLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, character_dict_path=None, use_space_char=False, + **kwargs): + super(SVTRLabelDecode, self).__init__(character_dict_path, + use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + if isinstance(preds, tuple): + preds = preds[-1] + if isinstance(preds, paddle.Tensor): + preds = preds.numpy() + preds_idx = preds.argmax(axis=-1) + preds_prob = preds.max(axis=-1) + + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) + return_text = [] + for i in range(0, len(text), 3): + text0 = text[i] + text1 = text[i + 1] + text2 = text[i + 2] + + text_pred = [text0[0], text1[0], text2[0]] + text_prob = [text0[1], text1[1], text2[1]] + id_max = text_prob.index(max(text_prob)) + return_text.append((text_pred[id_max], text_prob[id_max])) + if label is None: + return return_text + label = self.decode(label) + return return_text, label + + def add_special_char(self, dict_character): + dict_character = ['blank'] + dict_character + return dict_character \ No newline at end of file diff --git a/tools/export_model.py b/tools/export_model.py index 96cc05a2449ca005a4ea5767fc64c777a90d6114..003bc61f791b6c41a3b08d58ab87f12109744f9a 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -61,6 +61,11 @@ def export_single_model(model, arch_config, save_path, logger): paddle.static.InputSpec( shape=[None, 3, 48, -1], dtype="float32"), ] + else: + other_shape = [ + paddle.static.InputSpec( + shape=[None, 3, 64, 256], dtype="float32"), + ] model = to_static(model, input_spec=other_shape) elif arch_config["algorithm"] == "PREN": other_shape = [ diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index d4fbc3888ce4b42138fa3eda7774156e7e751fcd..2abc0220937175f95ee4c1e4b0b949d24d5fa3e8 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -131,6 +131,17 @@ class TextRecognizer(object): padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32) padding_im[:, :, 0:resized_w] = resized_image return padding_im + + def resize_norm_img_svtr(self, img, image_shape): + + imgC, imgH, imgW = image_shape + resized_image = cv2.resize( + img, (imgW, imgH), interpolation=cv2.INTER_LINEAR) + resized_image = resized_image.astype('float32') + resized_image = resized_image.transpose((2, 0, 1)) / 255 + resized_image -= 0.5 + resized_image /= 0.5 + return resized_image def resize_norm_img_srn(self, img, image_shape): imgC, imgH, imgW = image_shape @@ -263,12 +274,8 @@ class TextRecognizer(object): wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for ino in range(beg_img_no, end_img_no): - if self.rec_algorithm != "SRN" and self.rec_algorithm != "SAR": - norm_img = self.resize_norm_img(img_list[indices[ino]], - max_wh_ratio) - norm_img = norm_img[np.newaxis, :] - norm_img_batch.append(norm_img) - elif self.rec_algorithm == "SAR": + + if self.rec_algorithm == "SAR": norm_img, _, _, valid_ratio = self.resize_norm_img_sar( img_list[indices[ino]], self.rec_image_shape) norm_img = norm_img[np.newaxis, :] @@ -276,7 +283,7 @@ class TextRecognizer(object): valid_ratios = [] valid_ratios.append(valid_ratio) norm_img_batch.append(norm_img) - else: + elif self.rec_algorithm == "SRN": norm_img = self.process_image_srn( img_list[indices[ino]], self.rec_image_shape, 8, 25) encoder_word_pos_list = [] @@ -288,6 +295,16 @@ class TextRecognizer(object): gsrm_slf_attn_bias1_list.append(norm_img[3]) gsrm_slf_attn_bias2_list.append(norm_img[4]) norm_img_batch.append(norm_img[0]) + elif self.rec_algorithm == "SVTR": + norm_img = self.resize_norm_img_svtr( + img_list[indices[ino]], self.rec_image_shape) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + else: + norm_img = self.resize_norm_img(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() if self.benchmark: diff --git a/tools/train.py b/tools/train.py index 77e600ab6641f0baade072bb853db0d6d44052a4..42aba548d6bf5fc35f033ef2baca0fb54d79e75a 100755 --- a/tools/train.py +++ b/tools/train.py @@ -129,7 +129,7 @@ def main(config, device, logger, vdl_writer): config['Optimizer'], epochs=config['Global']['epoch_num'], step_each_epoch=len(train_dataloader), - parameters=model.parameters()) + model=model) # build metric eval_class = build_metric(config['Metric'])