提交 0cc083f6 编写于 作者: T tink2123

mv drop last in det model

上级 b55b8eda
...@@ -9,7 +9,6 @@ Global: ...@@ -9,7 +9,6 @@ Global:
eval_batch_step: 5000 eval_batch_step: 5000
train_batch_size_per_card: 16 train_batch_size_per_card: 16
test_batch_size_per_card: 16 test_batch_size_per_card: 16
drop_last: false
image_shape: [3, 640, 640] image_shape: [3, 640, 640]
reader_yml: ./configs/det/det_db_icdar15_reader.yml reader_yml: ./configs/det/det_db_icdar15_reader.yml
pretrain_weights: ./pretrain_models/MobileNetV3_large_x0_5_pretrained/ pretrain_weights: ./pretrain_models/MobileNetV3_large_x0_5_pretrained/
......
...@@ -9,7 +9,6 @@ Global: ...@@ -9,7 +9,6 @@ Global:
eval_batch_step: 5000 eval_batch_step: 5000
train_batch_size_per_card: 16 train_batch_size_per_card: 16
test_batch_size_per_card: 16 test_batch_size_per_card: 16
drop_last: false
image_shape: [3, 512, 512] image_shape: [3, 512, 512]
reader_yml: ./configs/det/det_east_icdar15_reader.yml reader_yml: ./configs/det/det_east_icdar15_reader.yml
pretrain_weights: ./pretrain_models/MobileNetV3_large_x0_5_pretrained/ pretrain_weights: ./pretrain_models/MobileNetV3_large_x0_5_pretrained/
......
...@@ -10,7 +10,6 @@ Global: ...@@ -10,7 +10,6 @@ Global:
train_batch_size_per_card: 8 train_batch_size_per_card: 8
test_batch_size_per_card: 16 test_batch_size_per_card: 16
image_shape: [3, 640, 640] image_shape: [3, 640, 640]
drop_last: false
reader_yml: ./configs/det/det_db_icdar15_reader.yml reader_yml: ./configs/det/det_db_icdar15_reader.yml
pretrain_weights: ./pretrain_models/ResNet50_vd_ssld_pretrained/ pretrain_weights: ./pretrain_models/ResNet50_vd_ssld_pretrained/
save_res_path: ./output/det_db/predicts_db.txt save_res_path: ./output/det_db/predicts_db.txt
......
...@@ -10,7 +10,6 @@ Global: ...@@ -10,7 +10,6 @@ Global:
train_batch_size_per_card: 8 train_batch_size_per_card: 8
test_batch_size_per_card: 16 test_batch_size_per_card: 16
image_shape: [3, 512, 512] image_shape: [3, 512, 512]
drop_last: false
reader_yml: ./configs/det/det_east_icdar15_reader.yml reader_yml: ./configs/det/det_east_icdar15_reader.yml
pretrain_weights: ./pretrain_models/ResNet50_vd_ssld_pretrained/ pretrain_weights: ./pretrain_models/ResNet50_vd_ssld_pretrained/
save_res_path: ./output/det_east/predicts_east.txt save_res_path: ./output/det_east/predicts_east.txt
......
TrainReader: TrainReader:
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
num_workers: 8 num_workers: 8
lmdb_sets_dir: ./train_data/data_lmdb_release/training/ lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
EvalReader: EvalReader:
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
......
Global: Global:
algorithm: CRNN algorithm: CRNN
use_gpu: true use_gpu: false
epoch_num: 1000 epoch_num: 1000
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/rec_CRNN save_model_dir: ./output/rec_CRNN
save_epoch_step: 300 save_epoch_step: 300
eval_batch_step: 500 eval_batch_step: 500
train_batch_size_per_card: 256 train_batch_size_per_card: 2
test_batch_size_per_card: 256 test_batch_size_per_card: 2
image_shape: [3, 32, 100] image_shape: [3, 32, 100]
max_text_length: 25 max_text_length: 25
character_type: en character_type: en
loss_type: ctc loss_type: ctc
reader_yml: ./configs/rec/rec_icdar15_reader.yml reader_yml: ./configs/rec/rec_icdar15_reader.yml
pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img: infer_img:
......
Global: Global:
algorithm: RARE algorithm: RARE
use_gpu: true use_gpu: false
epoch_num: 72 epoch_num: 72
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: output/rec_RARE save_model_dir: output/rec_RARE
save_epoch_step: 3 save_epoch_step: 3
eval_batch_step: 2000 eval_batch_step: 2000
train_batch_size_per_card: 256 train_batch_size_per_card: 2
test_batch_size_per_card: 256 test_batch_size_per_card: 2
image_shape: [3, 32, 100] image_shape: [3, 32, 100]
max_text_length: 25 max_text_length: 25
character_type: en character_type: en
......
...@@ -32,7 +32,6 @@ class TrainReader(object): ...@@ -32,7 +32,6 @@ class TrainReader(object):
self.num_workers = params['num_workers'] self.num_workers = params['num_workers']
self.label_file_path = params['label_file_path'] self.label_file_path = params['label_file_path']
self.batch_size = params['train_batch_size_per_card'] self.batch_size = params['train_batch_size_per_card']
self.drop_last = params['drop_last']
assert 'process_function' in params,\ assert 'process_function' in params,\
"absence process_function in Reader" "absence process_function in Reader"
self.process = create_module(params['process_function'])(params) self.process = create_module(params['process_function'])(params)
...@@ -62,9 +61,6 @@ class TrainReader(object): ...@@ -62,9 +61,6 @@ class TrainReader(object):
if len(batch_outs) == self.batch_size: if len(batch_outs) == self.batch_size:
yield batch_outs yield batch_outs
batch_outs = [] batch_outs = []
if not self.drop_last:
if len(batch_outs) != 0:
yield batch_outs
return batch_iter_reader return batch_iter_reader
......
...@@ -94,9 +94,9 @@ class RecModel(object): ...@@ -94,9 +94,9 @@ class RecModel(object):
logger.info( logger.info(
"WARNRNG!!!\n" "WARNRNG!!!\n"
"TPS does not support variable shape in chinese!" "TPS does not support variable shape in chinese!"
"We set default shape=[3,32,320], it may affect the inference effect" "We set img_shape to be the same , it may affect the inference effect"
) )
image_shape[-1] = 320 image_shape = deepcopy(self.image_shape)
image = fluid.data(name='image', shape=image_shape, dtype='float32') image = fluid.data(name='image', shape=image_shape, dtype='float32')
labels = None labels = None
loader = None loader = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册