未验证 提交 2e05d54a 编写于 作者: Z zhoujun 提交者: GitHub

Add dynamic-to-static (d2s) training support for SLANet and PP-OCRv3 (#9341)

* Add dynamic-to-static (d2s) training support for SLANet and PP-OCRv3

* fix bug
上级 623424fc
...@@ -17,6 +17,7 @@ Global: ...@@ -17,6 +17,7 @@ Global:
infer_img: doc/imgs_en/img_10.jpg infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./checkpoints/det_db/predicts_db.txt save_res_path: ./checkpoints/det_db/predicts_db.txt
distributed: true distributed: true
d2s_train_image_shape: [3, -1, -1]
Architecture: Architecture:
name: DistillationModel name: DistillationModel
......
...@@ -12,6 +12,7 @@ Global: ...@@ -12,6 +12,7 @@ Global:
use_visualdl: False use_visualdl: False
seed: 2022 seed: 2022
infer_img: ppstructure/docs/kie/input/zh_val_42.jpg infer_img: ppstructure/docs/kie/input/zh_val_42.jpg
d2s_train_image_shape: [3, 224, 224]
# if you want to predict using the groundtruth ocr info, # if you want to predict using the groundtruth ocr info,
# you can use the following config # you can use the following config
# infer_img: train_data/XFUND/zh_val/val.json # infer_img: train_data/XFUND/zh_val/val.json
......
...@@ -19,6 +19,7 @@ Global: ...@@ -19,6 +19,7 @@ Global:
use_space_char: true use_space_char: true
distributed: true distributed: true
save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt
d2s_train_image_shape: [3, 48, -1]
Optimizer: Optimizer:
......
...@@ -21,6 +21,7 @@ Global: ...@@ -21,6 +21,7 @@ Global:
infer_mode: False infer_mode: False
use_sync_bn: True use_sync_bn: True
save_res_path: 'output/infer' save_res_path: 'output/infer'
d2s_train_image_shape: [3, -1, -1]
Optimizer: Optimizer:
name: Adam name: Adam
......
...@@ -17,6 +17,7 @@ Global: ...@@ -17,6 +17,7 @@ Global:
infer_mode: false infer_mode: false
max_text_length: &max_text_length 500 max_text_length: &max_text_length 500
box_format: &box_format 'xywh' # 'xywh', 'xyxy', 'xyxyxyxy' box_format: &box_format 'xywh' # 'xywh', 'xyxy', 'xyxyxyxy'
d2s_train_image_shape: [3, 480, 480]
Optimizer: Optimizer:
......
...@@ -38,9 +38,9 @@ def build_model(config): ...@@ -38,9 +38,9 @@ def build_model(config):
def apply_to_static(model, config, logger): def apply_to_static(model, config, logger):
if config["Global"].get("to_static", False) is not True: if config["Global"].get("to_static", False) is not True:
return model return model
assert "image_shape" in config[ assert "d2s_train_image_shape" in config[
"Global"], "image_shape must be assigned for static training mode..." "Global"], "d2s_train_image_shape must be assigned for static training mode..."
supported_list = ["DB", "SVTR_LCNet", "TableMaster"] supported_list = ["DB", "SVTR_LCNet", "TableMaster", "LayoutXLM", "SLANet"]
if config["Architecture"]["algorithm"] in ["Distillation"]: if config["Architecture"]["algorithm"] in ["Distillation"]:
algo = list(config["Architecture"]["Models"].values())[0]["algorithm"] algo = list(config["Architecture"]["Models"].values())[0]["algorithm"]
else: else:
...@@ -49,7 +49,7 @@ def apply_to_static(model, config, logger): ...@@ -49,7 +49,7 @@ def apply_to_static(model, config, logger):
specs = [ specs = [
InputSpec( InputSpec(
[None] + config["Global"]["image_shape"], dtype='float32') [None] + config["Global"]["d2s_train_image_shape"], dtype='float32')
] ]
if algo == "SVTR_LCNet": if algo == "SVTR_LCNet":
...@@ -62,7 +62,7 @@ def apply_to_static(model, config, logger): ...@@ -62,7 +62,7 @@ def apply_to_static(model, config, logger):
[None], dtype='int64'), InputSpec( [None], dtype='int64'), InputSpec(
[None], dtype='float64') [None], dtype='float64')
]) ])
if algo == "TableMaster": elif algo == "TableMaster":
specs.append( specs.append(
[ [
InputSpec( InputSpec(
...@@ -76,6 +76,34 @@ def apply_to_static(model, config, logger): ...@@ -76,6 +76,34 @@ def apply_to_static(model, config, logger):
InputSpec( InputSpec(
[None, 6], dtype='float32'), [None, 6], dtype='float32'),
]) ])
elif algo == "LayoutXLM":
specs = [[
InputSpec(
shape=[None, 512], dtype="int64"), # input_ids
InputSpec(
shape=[None, 512, 4], dtype="int64"), # bbox
InputSpec(
shape=[None, 512], dtype="int64"), # attention_mask
InputSpec(
shape=[None, 512], dtype="int64"), # token_type_ids
InputSpec(
shape=[None, 3, 224, 224], dtype="float32"), # image
InputSpec(
shape=[None, 512], dtype="int64"), # label
]]
elif algo == "SLANet":
specs.append([
InputSpec(
[None, config["Global"]["max_text_length"] + 2], dtype='int64'),
InputSpec(
[None, config["Global"]["max_text_length"] + 2, 4],
dtype='float32'),
InputSpec(
[None, config["Global"]["max_text_length"] + 2, 1],
dtype='float32'),
InputSpec(
[None, 6], dtype='float64'),
])
model = to_static(model, input_spec=specs) model = to_static(model, input_spec=specs)
logger.info("Successfully to apply @to_static with specs: {}".format(specs)) logger.info("Successfully to apply @to_static with specs: {}".format(specs))
return model return model
...@@ -20,6 +20,8 @@ from tqdm import tqdm ...@@ -20,6 +20,8 @@ from tqdm import tqdm
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
MODELS_DIR = os.path.expanduser("~/.paddleocr/models/")
def download_with_progressbar(url, save_path): def download_with_progressbar(url, save_path):
logger = get_logger() logger = get_logger()
......
...@@ -17,7 +17,7 @@ norm_train:tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o ...@@ -17,7 +17,7 @@ norm_train:tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
null:null to_static_train:Global.to_static=true
null:null null:null
## ##
===========================eval_params=========================== ===========================eval_params===========================
......
...@@ -19,6 +19,7 @@ Global: ...@@ -19,6 +19,7 @@ Global:
use_space_char: true use_space_char: true
distributed: true distributed: true
save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt
d2s_train_image_shape: [3, 48, -1]
Optimizer: Optimizer:
......
...@@ -17,7 +17,7 @@ norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_d ...@@ -17,7 +17,7 @@ norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_d
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
null:null to_static_train:Global.to_static=true
null:null null:null
## ##
===========================eval_params=========================== ===========================eval_params===========================
......
...@@ -21,6 +21,7 @@ Global: ...@@ -21,6 +21,7 @@ Global:
infer_mode: False infer_mode: False
use_sync_bn: True use_sync_bn: True
save_res_path: 'output/infer' save_res_path: 'output/infer'
d2s_train_image_shape: [3, -1, -1]
Optimizer: Optimizer:
name: Adam name: Adam
......
...@@ -17,7 +17,7 @@ norm_train:tools/train.py -c test_tipc/configs/slanet/SLANet.yml -o Global.print ...@@ -17,7 +17,7 @@ norm_train:tools/train.py -c test_tipc/configs/slanet/SLANet.yml -o Global.print
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
null:null to_static_train:Global.to_static=true
null:null null:null
## ##
===========================eval_params=========================== ===========================eval_params===========================
......
...@@ -16,7 +16,7 @@ Global: ...@@ -16,7 +16,7 @@ Global:
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: false infer_mode: false
max_text_length: 500 max_text_length: 500
image_shape: [3, 480, 480] d2s_train_image_shape: [3, 480, 480]
Optimizer: Optimizer:
......
...@@ -17,7 +17,7 @@ norm_train:tools/train.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_z ...@@ -17,7 +17,7 @@ norm_train:tools/train.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_z
pact_train:null pact_train:null
fpgm_train:null fpgm_train:null
distill_train:null distill_train:null
null:null to_static_train:Global.to_static=true
null:null null:null
## ##
===========================eval_params=========================== ===========================eval_params===========================
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册