From 1139a6c9ec8034e78946d2576376192eb77682be Mon Sep 17 00:00:00 2001 From: zhoujun Date: Wed, 22 Feb 2023 15:25:05 +0800 Subject: [PATCH] add table master to benchmark (#9107) * Add custom detection and recognition model usage instructions in re * update * Add custom detection and recognition model usage instructions in re * add db net for benchmark * rename benckmark to PaddleOCR_benchmark * add addict to req * rename * add table master benckmark * support tablemaster d2s train --- ppocr/data/__init__.py | 4 +- ppocr/data/lmdb_dataset.py | 85 +++++++++++++++++++ ppocr/modeling/architectures/__init__.py | 17 +++- test_tipc/benchmark_train.sh | 30 +++++-- .../configs/table_master/table_master.yml | 17 ++-- .../table_master/train_infer_python.txt | 10 ++- test_tipc/prepare.sh | 7 ++ test_tipc/test_train_inference_python.sh | 19 ++--- 8 files changed, 159 insertions(+), 30 deletions(-) diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index b602a346..4164721b 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -34,7 +34,7 @@ import paddle.distributed as dist from ppocr.data.imaug import transform, create_operators from ppocr.data.simple_dataset import SimpleDataSet -from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR +from ppocr.data.lmdb_dataset import LMDBDataSet, LMDBDataSetSR, LMDBDataSetTableMaster from ppocr.data.pgnet_dataset import PGDataSet from ppocr.data.pubtab_dataset import PubTabDataSet @@ -55,7 +55,7 @@ def build_dataloader(config, mode, device, logger, seed=None): support_dict = [ 'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet', - 'LMDBDataSetSR' + 'LMDBDataSetSR', 'LMDBDataSetTableMaster' ] module_name = config[mode]['dataset']['name'] assert module_name in support_dict, Exception( diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index 295643e4..f3efb604 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -18,6 +18,7 @@ import lmdb import cv2 import string import six +import pickle from PIL import Image from .imaug import transform, create_operators @@ -203,3 +204,87 @@ class LMDBDataSetSR(LMDBDataSet): if outs is None: return self.__getitem__(np.random.randint(self.__len__())) return outs + + +class LMDBDataSetTableMaster(LMDBDataSet): + def load_hierarchical_lmdb_dataset(self, data_dir): + lmdb_sets = {} + dataset_idx = 0 + env = lmdb.open( + data_dir, + max_readers=32, + readonly=True, + lock=False, + readahead=False, + meminit=False) + txn = env.begin(write=False) + num_samples = int(pickle.loads(txn.get(b"__len__"))) + lmdb_sets[dataset_idx] = {"dirpath":data_dir, "env":env, \ + "txn":txn, "num_samples":num_samples} + return lmdb_sets + + def get_img_data(self, value): + """get_img_data""" + if not value: + return None + imgdata = np.frombuffer(value, dtype='uint8') + if imgdata is None: + return None + imgori = cv2.imdecode(imgdata, 1) + if imgori is None: + return None + return imgori + + def get_lmdb_sample_info(self, txn, index): + def convert_bbox(bbox_str_list): + bbox_list = [] + for bbox_str in bbox_str_list: + bbox_list.append(int(bbox_str)) + return bbox_list + + try: + data = pickle.loads(txn.get(str(index).encode('utf8'))) + except: + return None + + # img_name, img, info_lines + file_name = data[0] + bytes = data[1] + info_lines = data[2] # raw data from TableMASTER annotation file. + # parse info_lines + raw_data = info_lines.strip().split('\n') + raw_name, text = raw_data[0], raw_data[ + 1] # don't filter the samples's length over max_seq_len. + text = text.split(',') + bbox_str_list = raw_data[2:] + bbox_split = ',' + bboxes = [{ + 'bbox': convert_bbox(bsl.strip().split(bbox_split)), + 'tokens': ['1', '2'] + } for bsl in bbox_str_list] + + # advance parse bbox + # import pdb;pdb.set_trace() + + line_info = {} + line_info['file_name'] = file_name + line_info['structure'] = text + line_info['cells'] = bboxes + line_info['image'] = bytes + return line_info + + def __getitem__(self, idx): + lmdb_idx, file_idx = self.data_idx_order_list[idx] + lmdb_idx = int(lmdb_idx) + file_idx = int(file_idx) + data = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'], + file_idx) + if data is None: + return self.__getitem__(np.random.randint(self.__len__())) + outs = transform(data, self.ops) + if outs is None: + return self.__getitem__(np.random.randint(self.__len__())) + return outs + + def __len__(self): + return self.data_idx_order_list.shape[0] diff --git a/ppocr/modeling/architectures/__init__.py b/ppocr/modeling/architectures/__init__.py index 384ae4cc..2f8506b7 100755 --- a/ppocr/modeling/architectures/__init__.py +++ b/ppocr/modeling/architectures/__init__.py @@ -40,7 +40,7 @@ def apply_to_static(model, config, logger): return model assert "image_shape" in config[ "Global"], "image_shape must be assigned for static training mode..." - supported_list = ["DB", "SVTR_LCNet"] + supported_list = ["DB", "SVTR_LCNet", "TableMaster"] if config["Architecture"]["algorithm"] in ["Distillation"]: algo = list(config["Architecture"]["Models"].values())[0]["algorithm"] else: @@ -62,7 +62,20 @@ def apply_to_static(model, config, logger): [None], dtype='int64'), InputSpec( [None], dtype='float64') ]) - + if algo == "TableMaster": + specs.append( + [ + InputSpec( + [None, config["Global"]["max_text_length"]], dtype='int64'), + InputSpec( + [None, config["Global"]["max_text_length"], 4], + dtype='float32'), + InputSpec( + [None, config["Global"]["max_text_length"], 1], + dtype='float32'), + InputSpec( + [None, 6], dtype='float32'), + ]) model = to_static(model, input_spec=specs) logger.info("Successfully to apply @to_static with specs: {}".format(specs)) return model diff --git a/test_tipc/benchmark_train.sh b/test_tipc/benchmark_train.sh index 25fda8f9..725da8b0 100644 --- a/test_tipc/benchmark_train.sh +++ b/test_tipc/benchmark_train.sh @@ -72,6 +72,19 @@ FILENAME=$new_filename # MODE must be one of ['benchmark_train'] MODE=$2 PARAMS=$3 + +to_static="" +# parse "to_static" options and modify trainer into "to_static_trainer" +if [[ $PARAMS =~ "dynamicTostatic" ]] ;then + to_static="d2sT_" + sed -i 's/trainer:norm_train/trainer:to_static_train/g' $FILENAME + # clear PARAM contents + if [ $PARAMS = "to_static" ] ;then + PARAMS="" + fi +fi +# bash test_tipc/benchmark_train.sh test_tipc/configs/det_mv3_db_v2_0/train_benchmark.txt benchmark_train dynamic_bs8_fp32_DP_N1C8 +# bash test_tipc/benchmark_train.sh test_tipc/configs/det_mv3_db_v2_0/train_benchmark.txt benchmark_train dynamicTostatic_bs8_fp32_DP_N1C8 # bash test_tipc/benchmark_train.sh test_tipc/configs/det_mv3_db_v2_0/train_benchmark.txt benchmark_train dynamic_bs8_null_DP_N1C1 IFS=$'\n' # parser params from train_benchmark.txt @@ -140,6 +153,13 @@ if [ ! -n "$PARAMS" ] ;then fp_items_list=(${fp_items}) device_num_list=(N1C4) run_mode="DP" +elif [[ ${PARAMS} = "dynamicTostatic" ]];then + IFS="|" + model_type=$PARAMS + batch_size_list=(${batch_size}) + fp_items_list=(${fp_items}) + device_num_list=(N1C4) + run_mode="DP" else # parser params from input: modeltype_bs${bs_item}_${fp_item}_${run_mode}_${device_num} IFS="_" @@ -181,7 +201,7 @@ for batch_size in ${batch_size_list[*]}; do if [ ${#gpu_id} -le 1 ];then log_path="$SAVE_LOG/profiling_log" mkdir -p $log_path - log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_profiling" + log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}profiling" func_sed_params "$FILENAME" "${line_gpuid}" "0" # sed used gpu_id # set profile_option params tmp=`sed -i "${line_profile}s/.*/${profile_option}/" "${FILENAME}"` @@ -197,8 +217,8 @@ for batch_size in ${batch_size_list[*]}; do speed_log_path="$SAVE_LOG/index" mkdir -p $log_path mkdir -p $speed_log_path - log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_log" - speed_log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_speed" + log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}log" + speed_log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}speed" func_sed_params "$FILENAME" "${line_profile}" "null" # sed profile_id as null cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} benchmark_train > ${log_path}/${log_name} 2>&1 " echo $cmd @@ -232,8 +252,8 @@ for batch_size in ${batch_size_list[*]}; do speed_log_path="$SAVE_LOG/index" mkdir -p $log_path mkdir -p $speed_log_path - log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_log" - speed_log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_speed" + log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}log" + speed_log_name="${repo_name}_${model_name}_bs${batch_size}_${precision}_${run_mode}_${device_num}_${to_static}speed" func_sed_params "$FILENAME" "${line_gpuid}" "$gpu_id" # sed used gpu_id func_sed_params "$FILENAME" "${line_profile}" "null" # sed --profile_option as null cmd="bash test_tipc/test_train_inference_python.sh ${FILENAME} benchmark_train > ${log_path}/${log_name} 2>&1 " diff --git a/test_tipc/configs/table_master/table_master.yml b/test_tipc/configs/table_master/table_master.yml index 27f81683..f818a4c5 100644 --- a/test_tipc/configs/table_master/table_master.yml +++ b/test_tipc/configs/table_master/table_master.yml @@ -6,7 +6,7 @@ Global: save_model_dir: ./output/table_master/ save_epoch_step: 17 eval_batch_step: [0, 6259] - cal_metric_during_train: true + cal_metric_during_train: false pretrained_model: null checkpoints: save_inference_dir: output/table_master/infer @@ -16,6 +16,7 @@ Global: character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt infer_mode: false max_text_length: 500 + image_shape: [3, 480, 480] Optimizer: @@ -67,16 +68,15 @@ Metric: Train: dataset: - name: PubTabDataSet - data_dir: ./train_data/pubtabnet/train - label_file_list: [./train_data/pubtabnet/train.jsonl] + name: LMDBDataSetTableMaster + data_dir: train_data/StructureLabel_val_500/ transforms: - DecodeImage: img_mode: BGR channel_first: False - TableMasterLabelEncode: learn_empty_box: False - merge_no_span_structure: True + merge_no_span_structure: False replace_empty_cell_token: True - ResizeTableImage: max_len: 480 @@ -101,16 +101,15 @@ Train: Eval: dataset: - name: PubTabDataSet - data_dir: ./train_data/pubtabnet/test/ - label_file_list: [./train_data/pubtabnet/test.jsonl] + name: LMDBDataSetTableMaster + data_dir: train_data/StructureLabel_val_500/ transforms: - DecodeImage: img_mode: BGR channel_first: False - TableMasterLabelEncode: learn_empty_box: False - merge_no_span_structure: True + merge_no_span_structure: False replace_empty_cell_token: True - ResizeTableImage: max_len: 480 diff --git a/test_tipc/configs/table_master/train_infer_python.txt b/test_tipc/configs/table_master/train_infer_python.txt index c3a87173..a7111de4 100644 --- a/test_tipc/configs/table_master/train_infer_python.txt +++ b/test_tipc/configs/table_master/train_infer_python.txt @@ -13,11 +13,11 @@ train_infer_img_dir:./ppstructure/docs/table/table.jpg null:null ## trainer:norm_train -norm_train:tools/train.py -c test_tipc/configs/table_master/table_master.yml -o Global.print_batch_step=10 +norm_train:tools/train.py -c test_tipc/configs/table_master/table_master.yml -o Global.print_batch_step=1 pact_train:null fpgm_train:null distill_train:null -null:null +to_static_train:Global.to_static=true null:null ## ===========================eval_params=========================== @@ -51,3 +51,9 @@ null:null null:null ===========================infer_benchmark_params========================== random_infer_input:[{float32,[3,480,480]}] +===========================train_benchmark_params========================== +batch_size:10 +fp_items:fp32|fp16 +epoch:2 +--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile +flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 \ No newline at end of file diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index 17081635..da3ef905 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -138,6 +138,13 @@ if [ ${MODE} = "benchmark_train" ];then cd ../ fi + if [ ${model_name} == "table_master" ];then + wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar --no-check-certificate + cd ./pretrain_models/ && tar xf table_structure_tablemaster_train.tar && cd ../ + wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/StructureLabel_val_500.tar --no-check-certificate + cd ./train_data/ && tar xf StructureLabel_val_500.tar + cd ../ + fi fi if [ ${MODE} = "lite_train_lite_infer" ];then diff --git a/test_tipc/test_train_inference_python.sh b/test_tipc/test_train_inference_python.sh index e182fa57..04ba8adf 100644 --- a/test_tipc/test_train_inference_python.sh +++ b/test_tipc/test_train_inference_python.sh @@ -40,8 +40,8 @@ fpgm_key=$(func_parser_key "${lines[17]}") fpgm_trainer=$(func_parser_value "${lines[17]}") distill_key=$(func_parser_key "${lines[18]}") distill_trainer=$(func_parser_value "${lines[18]}") -trainer_key1=$(func_parser_key "${lines[19]}") -trainer_value1=$(func_parser_value "${lines[19]}") +to_static_key=$(func_parser_key "${lines[19]}") +to_static_value=$(func_parser_value "${lines[19]}") trainer_key2=$(func_parser_key "${lines[20]}") trainer_value2=$(func_parser_value "${lines[20]}") @@ -253,9 +253,9 @@ else elif [ ${trainer} = "${distill_key}" ]; then run_train=${distill_trainer} run_export=${distill_export} - elif [ ${trainer} = ${trainer_key1} ]; then - run_train=${trainer_value1} - run_export=${export_value1} + elif [ ${trainer} = "${to_static_key}" ]; then + run_train="${norm_trainer} ${to_static_value}" + run_export=${norm_export} elif [[ ${trainer} = ${trainer_key2} ]]; then run_train=${trainer_value2} run_export=${export_value2} @@ -289,11 +289,11 @@ else set_save_model=$(func_set_params "${save_model_key}" "${save_log}") if [ ${#gpu} -le 2 ];then # train with cpu or single gpu - cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_batchsize} ${set_train_params1} ${set_amp_config} " + cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_batchsize} ${set_amp_config} ${set_train_params1}" elif [ ${#ips} -le 15 ];then # train with multi-gpu - cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_batchsize} ${set_train_params1} ${set_amp_config}" + cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_batchsize} ${set_amp_config} ${set_train_params1}" else # train with multi-machine - cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_batchsize} ${set_train_params1} ${set_amp_config}" + cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_batchsize} ${set_amp_config} ${set_train_params1}" fi # run train eval $cmd @@ -337,5 +337,4 @@ else done # done with: for trainer in ${trainer_list[*]}; do done # done with: for autocast in ${autocast_list[*]}; do done # done with: for gpu in ${gpu_list[*]}; do -fi # end if [ ${MODE} = "infer" ]; then - +fi # end if [ ${MODE} = "infer" ]; then \ No newline at end of file -- GitLab