未验证 提交 2e05d54a 编写于 作者: Z zhoujun 提交者: GitHub

add d2s train for slanet and v3 (#9341)

* add d2s train for slanet and v3

* fix bug
上级 623424fc
......@@ -17,6 +17,7 @@ Global:
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./checkpoints/det_db/predicts_db.txt
distributed: true
d2s_train_image_shape: [3, -1, -1]
Architecture:
name: DistillationModel
......
......@@ -12,6 +12,7 @@ Global:
use_visualdl: False
seed: 2022
infer_img: ppstructure/docs/kie/input/zh_val_42.jpg
d2s_train_image_shape: [3, 224, 224]
# if you want to predict using the groundtruth ocr info,
# you can use the following config
# infer_img: train_data/XFUND/zh_val/val.json
......
......@@ -19,6 +19,7 @@ Global:
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt
d2s_train_image_shape: [3, 48, -1]
Optimizer:
......
......@@ -21,6 +21,7 @@ Global:
infer_mode: False
use_sync_bn: True
save_res_path: 'output/infer'
d2s_train_image_shape: [3, -1, -1]
Optimizer:
name: Adam
......
......@@ -17,6 +17,7 @@ Global:
infer_mode: false
max_text_length: &max_text_length 500
box_format: &box_format 'xywh' # 'xywh', 'xyxy', 'xyxyxyxy'
d2s_train_image_shape: [3, 480, 480]
Optimizer:
......
......@@ -38,9 +38,9 @@ def build_model(config):
def apply_to_static(model, config, logger):
if config["Global"].get("to_static", False) is not True:
return model
assert "image_shape" in config[
"Global"], "image_shape must be assigned for static training mode..."
supported_list = ["DB", "SVTR_LCNet", "TableMaster"]
assert "d2s_train_image_shape" in config[
"Global"], "d2s_train_image_shape must be assigned for static training mode..."
supported_list = ["DB", "SVTR_LCNet", "TableMaster", "LayoutXLM", "SLANet"]
if config["Architecture"]["algorithm"] in ["Distillation"]:
algo = list(config["Architecture"]["Models"].values())[0]["algorithm"]
else:
......@@ -49,7 +49,7 @@ def apply_to_static(model, config, logger):
specs = [
InputSpec(
[None] + config["Global"]["image_shape"], dtype='float32')
[None] + config["Global"]["d2s_train_image_shape"], dtype='float32')
]
if algo == "SVTR_LCNet":
......@@ -62,7 +62,7 @@ def apply_to_static(model, config, logger):
[None], dtype='int64'), InputSpec(
[None], dtype='float64')
])
if algo == "TableMaster":
elif algo == "TableMaster":
specs.append(
[
InputSpec(
......@@ -76,6 +76,34 @@ def apply_to_static(model, config, logger):
InputSpec(
[None, 6], dtype='float32'),
])
elif algo == "LayoutXLM":
specs = [[
InputSpec(
shape=[None, 512], dtype="int64"), # input_ids
InputSpec(
shape=[None, 512, 4], dtype="int64"), # bbox
InputSpec(
shape=[None, 512], dtype="int64"), # attention_mask
InputSpec(
shape=[None, 512], dtype="int64"), # token_type_ids
InputSpec(
shape=[None, 3, 224, 224], dtype="float32"), # image
InputSpec(
shape=[None, 512], dtype="int64"), # label
]]
elif algo == "SLANet":
specs.append([
InputSpec(
[None, config["Global"]["max_text_length"] + 2], dtype='int64'),
InputSpec(
[None, config["Global"]["max_text_length"] + 2, 4],
dtype='float32'),
InputSpec(
[None, config["Global"]["max_text_length"] + 2, 1],
dtype='float32'),
InputSpec(
[None, 6], dtype='float64'),
])
model = to_static(model, input_spec=specs)
logger.info("Successfully to apply @to_static with specs: {}".format(specs))
return model
......@@ -20,6 +20,8 @@ from tqdm import tqdm
from ppocr.utils.logging import get_logger
MODELS_DIR = os.path.expanduser("~/.paddleocr/models/")
def download_with_progressbar(url, save_path):
logger = get_logger()
......
......@@ -17,7 +17,7 @@ norm_train:tools/train.py -c configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_cml.yml -o
pact_train:null
fpgm_train:null
distill_train:null
null:null
to_static_train:Global.to_static=true
null:null
##
===========================eval_params===========================
......
......@@ -19,6 +19,7 @@ Global:
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt
d2s_train_image_shape: [3, 48, -1]
Optimizer:
......
......@@ -17,7 +17,7 @@ norm_train:tools/train.py -c test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_d
pact_train:null
fpgm_train:null
distill_train:null
null:null
to_static_train:Global.to_static=true
null:null
##
===========================eval_params===========================
......
......@@ -21,6 +21,7 @@ Global:
infer_mode: False
use_sync_bn: True
save_res_path: 'output/infer'
d2s_train_image_shape: [3, -1, -1]
Optimizer:
name: Adam
......
......@@ -17,7 +17,7 @@ norm_train:tools/train.py -c test_tipc/configs/slanet/SLANet.yml -o Global.print
pact_train:null
fpgm_train:null
distill_train:null
null:null
to_static_train:Global.to_static=true
null:null
##
===========================eval_params===========================
......
......@@ -16,7 +16,7 @@ Global:
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: false
max_text_length: 500
image_shape: [3, 480, 480]
d2s_train_image_shape: [3, 480, 480]
Optimizer:
......
......@@ -17,7 +17,7 @@ norm_train:tools/train.py -c ./configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_z
pact_train:null
fpgm_train:null
distill_train:null
null:null
to_static_train:Global.to_static=true
null:null
##
===========================eval_params===========================
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册