diff --git a/doc/doc_ch/detection.md b/doc/doc_ch/detection.md
index 2cf0732219ac9cd2309ae24896d7de1499986461..eba5213501ad82b198438af1ae69dd8c7ad1071e 100644
--- a/doc/doc_ch/detection.md
+++ b/doc/doc_ch/detection.md
@@ -65,7 +65,7 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/
```
-上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。
+上述指令中,通过-c 选择训练使用configs/det/det_mv3_db.yml配置文件。
有关配置文件的详细解释,请参考[链接](./config.md)。
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
diff --git a/doc/doc_en/PP-OCRv3_introduction_en.md b/doc/doc_en/PP-OCRv3_introduction_en.md
index 481e0b8174b1e5ebce84eb1745c49dccd2c565f5..815ad9b0e5a7ff2dec36ceaef995212d122a9f89 100644
--- a/doc/doc_en/PP-OCRv3_introduction_en.md
+++ b/doc/doc_en/PP-OCRv3_introduction_en.md
@@ -55,10 +55,11 @@ The ablation experiments are as follows:
|ID|Strategy|Model Size|Hmean|The Inference Time(cpu + mkldnn)|
|-|-|-|-|-|
-|baseline teacher|DB-R50|99M|83.5%|260ms|
+|baseline teacher|PP-OCR server|49M|83.2%|171ms|
|teacher1|DB-R50-LK-PAN|124M|85.0%|396ms|
|teacher2|DB-R50-LK-PAN-DML|124M|86.0%|396ms|
|baseline student|PP-OCRv2|3M|83.2%|117ms|
+|student0|DB-MV3-RSE-FPN|3.6M|84.5%|124ms|
|student1|DB-MV3-CML(teacher2)|3M|84.3%|117ms|
|student2|DB-MV3-RSE-FPN-CML(teacher2)|3.6M|85.4%|124ms|
@@ -199,7 +200,7 @@ UDML (Unified-Deep Mutual Learning) is a strategy proposed in PP-OCRv2 which is
**(6)UIM:Unlabeled Images Mining**
-UIM (Unlabeled Images Mining) is a very simple unlabeled data mining strategy. The main idea is to use a high-precision text recognition model to predict unlabeled images to obtain pseudo-labels, and select samples with high prediction confidence as training data for training lightweight models. Using this strategy, the accuracy of the recognition model is further improved to 79.4% (+1%).
+UIM (Unlabeled Images Mining) is a very simple unlabeled data mining strategy. The main idea is to use a high-precision text recognition model to predict unlabeled images to obtain pseudo-labels, and select samples with high prediction confidence as training data for training lightweight models. Using this strategy, the accuracy of the recognition model is further improved to 79.4% (+1%). In practice, we use the full data set to train the high-precision SVTR_Tiny model (acc=82.5%) for data mining. [SVTR_Tiny model download and tutorial](../../applications/高精度中文识别模型.md).

diff --git a/doc/doc_en/detection_en.md b/doc/doc_en/detection_en.md
index f85bf585cb66332d90de8d66ed315cb04ece7636..c215e1a46636a84d372245097b460c095e9cb7fd 100644
--- a/doc/doc_en/detection_en.md
+++ b/doc/doc_en/detection_en.md
@@ -51,7 +51,7 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml \
-o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
```
-In the above instruction, use `-c` to select the training to use the `configs/det/det_db_mv3.yml` configuration file.
+In the above instruction, use `-c` to select the training to use the `configs/det/det_mv3_db.yml` configuration file.
For a detailed explanation of the configuration file, please refer to [config](./config_en.md).
You can also use `-o` to change the training parameters without modifying the yml file. For example, adjust the training learning rate to 0.0001
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 0698696a080017ca65679aab60ee7d987c90c824..1656c69529e19ee04fcb4343f28fe742dabb83b0 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -26,6 +26,7 @@ import copy
from random import sample
from ppocr.utils.logging import get_logger
+from ppocr.data.imaug.vqa.augment import order_by_tbyx
class ClsLabelEncode(object):
@@ -873,6 +874,7 @@ class VQATokenLabelEncode(object):
add_special_ids=False,
algorithm='LayoutXLM',
use_textline_bbox_info=True,
+ order_method=None,
infer_mode=False,
ocr_engine=None,
**kwargs):
@@ -902,6 +904,8 @@ class VQATokenLabelEncode(object):
self.infer_mode = infer_mode
self.ocr_engine = ocr_engine
self.use_textline_bbox_info = use_textline_bbox_info
+ self.order_method = order_method
+ assert self.order_method in [None, "tb-yx"]
def split_bbox(self, bbox, text, tokenizer):
words = text.split()
@@ -941,6 +945,14 @@ class VQATokenLabelEncode(object):
# load bbox and label info
ocr_info = self._load_ocr_info(data)
+ for idx in range(len(ocr_info)):
+ if "bbox" not in ocr_info[idx]:
+ ocr_info[idx]["bbox"] = self.trans_poly_to_bbox(ocr_info[idx][
+ "points"])
+
+ if self.order_method == "tb-yx":
+ ocr_info = order_by_tbyx(ocr_info)
+
# for re
train_re = self.contains_re and not self.infer_mode
if train_re:
@@ -980,7 +992,10 @@ class VQATokenLabelEncode(object):
info["bbox"] = self.trans_poly_to_bbox(info["points"])
encode_res = self.tokenizer.encode(
- text, pad_to_max_seq_len=False, return_attention_mask=True)
+ text,
+ pad_to_max_seq_len=False,
+ return_attention_mask=True,
+ return_token_type_ids=True)
if not self.add_special_ids:
# TODO: use tok.all_special_ids to remove
@@ -1052,10 +1067,10 @@ class VQATokenLabelEncode(object):
return data
def trans_poly_to_bbox(self, poly):
- x1 = np.min([p[0] for p in poly])
- x2 = np.max([p[0] for p in poly])
- y1 = np.min([p[1] for p in poly])
- y2 = np.max([p[1] for p in poly])
+ x1 = int(np.min([p[0] for p in poly]))
+ x2 = int(np.max([p[0] for p in poly]))
+ y1 = int(np.min([p[1] for p in poly]))
+ y2 = int(np.max([p[1] for p in poly]))
return [x1, y1, x2, y2]
def _load_ocr_info(self, data):
diff --git a/ppocr/data/imaug/vqa/__init__.py b/ppocr/data/imaug/vqa/__init__.py
index bde175115536a3f644750260082204fe5f10dc05..34189bcefb17a0776bd62a19c58081286882b5a5 100644
--- a/ppocr/data/imaug/vqa/__init__.py
+++ b/ppocr/data/imaug/vqa/__init__.py
@@ -13,12 +13,10 @@
# limitations under the License.
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
-from .augment import DistortBBox
__all__ = [
'VQATokenPad',
'VQASerTokenChunk',
'VQAReTokenChunk',
'VQAReTokenRelation',
- 'DistortBBox',
]
diff --git a/ppocr/data/imaug/vqa/augment.py b/ppocr/data/imaug/vqa/augment.py
index fcdc9685e9855c3a2d8e9f6f5add270f95f15a6c..b95fcdf0f0baea481de59321a22dab283d99e693 100644
--- a/ppocr/data/imaug/vqa/augment.py
+++ b/ppocr/data/imaug/vqa/augment.py
@@ -16,22 +16,18 @@ import os
import sys
import numpy as np
import random
+from copy import deepcopy
-class DistortBBox:
- def __init__(self, prob=0.5, max_scale=1, **kwargs):
- """Random distort bbox
- """
- self.prob = prob
- self.max_scale = max_scale
-
- def __call__(self, data):
- if random.random() > self.prob:
- return data
- bbox = np.array(data['bbox'])
- rnd_scale = (np.random.rand(*bbox.shape) - 0.5) * 2 * self.max_scale
- bbox = np.round(bbox + rnd_scale).astype(bbox.dtype)
- data['bbox'] = np.clip(data['bbox'], 0, 1000)
- data['bbox'] = bbox.tolist()
- sys.stdout.flush()
- return data
+def order_by_tbyx(ocr_info):
+ res = sorted(ocr_info, key=lambda r: (r["bbox"][1], r["bbox"][0]))
+ for i in range(len(res) - 1):
+ for j in range(i, 0, -1):
+ if abs(res[j + 1]["bbox"][1] - res[j]["bbox"][1]) < 20 and \
+ (res[j + 1]["bbox"][0] < res[j]["bbox"][0]):
+ tmp = deepcopy(res[j])
+ res[j] = deepcopy(res[j + 1])
+ res[j + 1] = deepcopy(tmp)
+ else:
+ break
+ return res
diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py
index 74490791c2af0be54dab8ab30ac323790fcac657..da9faa08bc5ca35c5d65f7a7bfbbdd67192f052b 100644
--- a/ppocr/losses/basic_loss.py
+++ b/ppocr/losses/basic_loss.py
@@ -63,18 +63,21 @@ class KLJSLoss(object):
def __call__(self, p1, p2, reduction="mean"):
if self.mode.lower() == 'kl':
- loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
+ loss = paddle.multiply(p2,
+ paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
loss += paddle.multiply(
- p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
+ p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
loss *= 0.5
elif self.mode.lower() == "js":
- loss = paddle.multiply(p2, paddle.log((2*p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
+ loss = paddle.multiply(
+ p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss += paddle.multiply(
- p1, paddle.log((2*p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
+ p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss *= 0.5
else:
- raise ValueError("The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
-
+ raise ValueError(
+ "The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
+
if reduction == "mean":
loss = paddle.mean(loss, axis=[1, 2])
elif reduction == "none" or reduction is None:
@@ -154,7 +157,9 @@ class LossFromOutput(nn.Layer):
self.reduction = reduction
def forward(self, predicts, batch):
- loss = predicts[self.key]
+ loss = predicts
+ if self.key is not None and isinstance(predicts, dict):
+ loss = loss[self.key]
if self.reduction == 'mean':
loss = paddle.mean(loss)
elif self.reduction == 'sum':
diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py
index f4cdee8f90465e863b89d1e32b4a0285adb29eff..8d697d544b51899cdafeff94be2ecce067b907a2 100644
--- a/ppocr/losses/combined_loss.py
+++ b/ppocr/losses/combined_loss.py
@@ -24,6 +24,9 @@ from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationSARLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
+from .distillation_loss import DistillationVQASerTokenLayoutLMLoss, DistillationSERDMLLoss
+from .distillation_loss import DistillationLossFromOutput
+from .distillation_loss import DistillationVQADistanceLoss
class CombinedLoss(nn.Layer):
diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py
index 565b066d1334e6caa1b6b4094706265f363b66ef..87fed6235d73aef2695cd6db95662e615d52c94c 100644
--- a/ppocr/losses/distillation_loss.py
+++ b/ppocr/losses/distillation_loss.py
@@ -21,8 +21,10 @@ from .rec_ctc_loss import CTCLoss
from .rec_sar_loss import SARLoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
+from .basic_loss import LossFromOutput
from .det_db_loss import DBLoss
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
+from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
def _sum_loss(loss_dict):
@@ -322,3 +324,133 @@ class DistillationDistanceLoss(DistanceLoss):
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = loss
return loss_dict
+
+
+class DistillationVQASerTokenLayoutLMLoss(VQASerTokenLayoutLMLoss):
+ def __init__(self,
+ num_classes,
+ model_name_list=[],
+ key=None,
+ name="loss_ser"):
+ super().__init__(num_classes=num_classes)
+ self.model_name_list = model_name_list
+ self.key = key
+ self.name = name
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, model_name in enumerate(self.model_name_list):
+ out = predicts[model_name]
+ if self.key is not None:
+ out = out[self.key]
+ loss = super().forward(out, batch)
+ loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
+ return loss_dict
+
+
+class DistillationLossFromOutput(LossFromOutput):
+ def __init__(self,
+ reduction="none",
+ model_name_list=[],
+ dist_key=None,
+ key="loss",
+ name="loss_re"):
+ super().__init__(key=key, reduction=reduction)
+ self.model_name_list = model_name_list
+ self.name = name
+ self.dist_key = dist_key
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, model_name in enumerate(self.model_name_list):
+ out = predicts[model_name]
+ if self.dist_key is not None:
+ out = out[self.dist_key]
+ loss = super().forward(out, batch)
+ loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
+ return loss_dict
+
+
+class DistillationSERDMLLoss(DMLLoss):
+ """
+ """
+
+ def __init__(self,
+ act="softmax",
+ use_log=True,
+ num_classes=7,
+ model_name_pairs=[],
+ key=None,
+ name="loss_dml_ser"):
+ super().__init__(act=act, use_log=use_log)
+ assert isinstance(model_name_pairs, list)
+ self.key = key
+ self.name = name
+ self.num_classes = num_classes
+ self.model_name_pairs = model_name_pairs
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, pair in enumerate(self.model_name_pairs):
+ out1 = predicts[pair[0]]
+ out2 = predicts[pair[1]]
+ if self.key is not None:
+ out1 = out1[self.key]
+ out2 = out2[self.key]
+ out1 = out1.reshape([-1, out1.shape[-1]])
+ out2 = out2.reshape([-1, out2.shape[-1]])
+
+ attention_mask = batch[2]
+ if attention_mask is not None:
+ active_output = attention_mask.reshape([-1, ]) == 1
+ out1 = out1[active_output]
+ out2 = out2[active_output]
+
+ loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1,
+ out2)
+
+ return loss_dict
+
+
+class DistillationVQADistanceLoss(DistanceLoss):
+ def __init__(self,
+ mode="l2",
+ model_name_pairs=[],
+ key=None,
+ name="loss_distance",
+ **kargs):
+ super().__init__(mode=mode, **kargs)
+ assert isinstance(model_name_pairs, list)
+ self.key = key
+ self.model_name_pairs = model_name_pairs
+ self.name = name + "_l2"
+
+ def forward(self, predicts, batch):
+ loss_dict = dict()
+ for idx, pair in enumerate(self.model_name_pairs):
+ out1 = predicts[pair[0]]
+ out2 = predicts[pair[1]]
+ attention_mask = batch[2]
+ if self.key is not None:
+ out1 = out1[self.key]
+ out2 = out2[self.key]
+ if attention_mask is not None:
+ max_len = attention_mask.shape[-1]
+ out1 = out1[:, :max_len]
+ out2 = out2[:, :max_len]
+ out1 = out1.reshape([-1, out1.shape[-1]])
+ out2 = out2.reshape([-1, out2.shape[-1]])
+ if attention_mask is not None:
+ active_output = attention_mask.reshape([-1, ]) == 1
+ out1 = out1[active_output]
+ out2 = out2[active_output]
+
+ loss = super().forward(out1, out2)
+ if isinstance(loss, dict):
+ for key in loss:
+ loss_dict["{}_{}nohu_{}".format(self.name, key,
+ idx)] = loss[key]
+ else:
+ loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
+ idx)] = loss
+ return loss_dict
diff --git a/ppocr/losses/vqa_token_layoutlm_loss.py b/ppocr/losses/vqa_token_layoutlm_loss.py
index f9cd4634731a26dd990d6ffac3d8defc8cdf7e97..5d564c0e26f8fea1359ba3aa489359873a033cb9 100755
--- a/ppocr/losses/vqa_token_layoutlm_loss.py
+++ b/ppocr/losses/vqa_token_layoutlm_loss.py
@@ -17,26 +17,30 @@ from __future__ import division
from __future__ import print_function
from paddle import nn
+from ppocr.losses.basic_loss import DMLLoss
class VQASerTokenLayoutLMLoss(nn.Layer):
- def __init__(self, num_classes):
+ def __init__(self, num_classes, key=None):
super().__init__()
self.loss_class = nn.CrossEntropyLoss()
self.num_classes = num_classes
self.ignore_index = self.loss_class.ignore_index
+ self.key = key
def forward(self, predicts, batch):
+ if isinstance(predicts, dict) and self.key is not None:
+ predicts = predicts[self.key]
labels = batch[5]
attention_mask = batch[2]
if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1
- active_outputs = predicts.reshape(
+ active_output = predicts.reshape(
[-1, self.num_classes])[active_loss]
- active_labels = labels.reshape([-1, ])[active_loss]
- loss = self.loss_class(active_outputs, active_labels)
+ active_label = labels.reshape([-1, ])[active_loss]
+ loss = self.loss_class(active_output, active_label)
else:
loss = self.loss_class(
predicts.reshape([-1, self.num_classes]),
labels.reshape([-1, ]))
- return {'loss': loss}
+ return {'loss': loss}
\ No newline at end of file
diff --git a/ppocr/metrics/distillation_metric.py b/ppocr/metrics/distillation_metric.py
index c440cebdd0f96493fc33000a0d304cbe5e3f0624..e2cbc4dc07c4e7b234ca964eb9fc0259dfab6ab4 100644
--- a/ppocr/metrics/distillation_metric.py
+++ b/ppocr/metrics/distillation_metric.py
@@ -19,6 +19,8 @@ from .rec_metric import RecMetric
from .det_metric import DetMetric
from .e2e_metric import E2EMetric
from .cls_metric import ClsMetric
+from .vqa_token_ser_metric import VQASerTokenMetric
+from .vqa_token_re_metric import VQAReTokenMetric
class DistillationMetric(object):
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index c6b50d4886daa9bfd2f863c1d8fd6dbc3d1e42c0..ed2a909cb58d56ec5a67b897de1a171658228acb 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -73,28 +73,40 @@ class BaseModel(nn.Layer):
self.return_all_feats = config.get("return_all_feats", False)
def forward(self, x, data=None):
+
y = dict()
if self.use_transform:
x = self.transform(x)
x = self.backbone(x)
- y["backbone_out"] = x
+ if isinstance(x, dict):
+ y.update(x)
+ else:
+ y["backbone_out"] = x
+ final_name = "backbone_out"
if self.use_neck:
x = self.neck(x)
- y["neck_out"] = x
+ if isinstance(x, dict):
+ y.update(x)
+ else:
+ y["neck_out"] = x
+ final_name = "neck_out"
if self.use_head:
x = self.head(x, targets=data)
- # for multi head, save ctc neck out for udml
- if isinstance(x, dict) and 'ctc_neck' in x.keys():
- y["neck_out"] = x["ctc_neck"]
- y["head_out"] = x
- elif isinstance(x, dict):
- y.update(x)
- else:
- y["head_out"] = x
+ # for multi head, save ctc neck out for udml
+ if isinstance(x, dict) and 'ctc_neck' in x.keys():
+ y["neck_out"] = x["ctc_neck"]
+ y["head_out"] = x
+ elif isinstance(x, dict):
+ y.update(x)
+ else:
+ y["head_out"] = x
+ final_name = "head_out"
if self.return_all_feats:
if self.training:
return y
+ elif isinstance(x, dict):
+ return x
else:
- return {"head_out": y["head_out"]}
+ return {final_name: x}
else:
return x
diff --git a/ppocr/modeling/backbones/vqa_layoutlm.py b/ppocr/modeling/backbones/vqa_layoutlm.py
index 34dd9d10ea36758059448d96674d4d2c249d3ad0..d4ced350885bd54e6c6065cb0f21c45780c136b0 100644
--- a/ppocr/modeling/backbones/vqa_layoutlm.py
+++ b/ppocr/modeling/backbones/vqa_layoutlm.py
@@ -22,13 +22,22 @@ from paddle import nn
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction
+from paddlenlp.transformers import AutoModel
-__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']
+__all__ = ["LayoutXLMForSer", "LayoutLMForSer"]
pretrained_model_dict = {
- LayoutXLMModel: 'layoutxlm-base-uncased',
- LayoutLMModel: 'layoutlm-base-uncased',
- LayoutLMv2Model: 'layoutlmv2-base-uncased'
+ LayoutXLMModel: {
+ "base": "layoutxlm-base-uncased",
+ "vi": "layoutxlm-wo-backbone-base-uncased",
+ },
+ LayoutLMModel: {
+ "base": "layoutlm-base-uncased",
+ },
+ LayoutLMv2Model: {
+ "base": "layoutlmv2-base-uncased",
+ "vi": "layoutlmv2-wo-backbone-base-uncased",
+ },
}
@@ -36,42 +45,47 @@ class NLPBaseModel(nn.Layer):
def __init__(self,
base_model_class,
model_class,
- type='ser',
+ mode="base",
+ type="ser",
pretrained=True,
checkpoints=None,
**kwargs):
super(NLPBaseModel, self).__init__()
- if checkpoints is not None:
+ if checkpoints is not None: # load the trained model
self.model = model_class.from_pretrained(checkpoints)
- elif isinstance(pretrained, (str, )) and os.path.exists(pretrained):
- self.model = model_class.from_pretrained(pretrained)
- else:
- pretrained_model_name = pretrained_model_dict[base_model_class]
+ else: # load the pretrained-model
+ pretrained_model_name = pretrained_model_dict[base_model_class][
+ mode]
if pretrained is True:
base_model = base_model_class.from_pretrained(
pretrained_model_name)
else:
- base_model = base_model_class(
- **base_model_class.pretrained_init_configuration[
- pretrained_model_name])
- if type == 'ser':
+ base_model = base_model_class.from_pretrained(pretrained)
+ if type == "ser":
self.model = model_class(
- base_model, num_classes=kwargs['num_classes'], dropout=None)
+ base_model, num_classes=kwargs["num_classes"], dropout=None)
else:
self.model = model_class(base_model, dropout=None)
self.out_channels = 1
+ self.use_visual_backbone = True
class LayoutLMForSer(NLPBaseModel):
- def __init__(self, num_classes, pretrained=True, checkpoints=None,
+ def __init__(self,
+ num_classes,
+ pretrained=True,
+ checkpoints=None,
+ mode="base",
**kwargs):
super(LayoutLMForSer, self).__init__(
LayoutLMModel,
LayoutLMForTokenClassification,
- 'ser',
+ mode,
+ "ser",
pretrained,
checkpoints,
- num_classes=num_classes)
+ num_classes=num_classes, )
+ self.use_visual_backbone = False
def forward(self, x):
x = self.model(
@@ -85,62 +99,92 @@ class LayoutLMForSer(NLPBaseModel):
class LayoutLMv2ForSer(NLPBaseModel):
- def __init__(self, num_classes, pretrained=True, checkpoints=None,
+ def __init__(self,
+ num_classes,
+ pretrained=True,
+ checkpoints=None,
+ mode="base",
**kwargs):
super(LayoutLMv2ForSer, self).__init__(
LayoutLMv2Model,
LayoutLMv2ForTokenClassification,
- 'ser',
+ mode,
+ "ser",
pretrained,
checkpoints,
num_classes=num_classes)
+ self.use_visual_backbone = True
+ if hasattr(self.model.layoutlmv2, "use_visual_backbone"
+ ) and self.model.layoutlmv2.use_visual_backbone is False:
+ self.use_visual_backbone = False
def forward(self, x):
+ if self.use_visual_backbone is True:
+ image = x[4]
+ else:
+ image = None
x = self.model(
input_ids=x[0],
bbox=x[1],
attention_mask=x[2],
token_type_ids=x[3],
- image=x[4],
+ image=image,
position_ids=None,
head_mask=None,
labels=None)
- if not self.training:
+ if self.training:
+ res = {"backbone_out": x[0]}
+ res.update(x[1])
+ return res
+ else:
return x
- return x[0]
class LayoutXLMForSer(NLPBaseModel):
- def __init__(self, num_classes, pretrained=True, checkpoints=None,
+ def __init__(self,
+ num_classes,
+ pretrained=True,
+ checkpoints=None,
+ mode="base",
**kwargs):
super(LayoutXLMForSer, self).__init__(
LayoutXLMModel,
LayoutXLMForTokenClassification,
- 'ser',
+ mode,
+ "ser",
pretrained,
checkpoints,
num_classes=num_classes)
+ self.use_visual_backbone = True
def forward(self, x):
+ if self.use_visual_backbone is True:
+ image = x[4]
+ else:
+ image = None
x = self.model(
input_ids=x[0],
bbox=x[1],
attention_mask=x[2],
token_type_ids=x[3],
- image=x[4],
+ image=image,
position_ids=None,
head_mask=None,
labels=None)
- if not self.training:
+ if self.training:
+ res = {"backbone_out": x[0]}
+ res.update(x[1])
+ return res
+ else:
return x
- return x[0]
class LayoutLMv2ForRe(NLPBaseModel):
- def __init__(self, pretrained=True, checkpoints=None, **kwargs):
- super(LayoutLMv2ForRe, self).__init__(LayoutLMv2Model,
- LayoutLMv2ForRelationExtraction,
- 're', pretrained, checkpoints)
+ def __init__(self, pretrained=True, checkpoints=None, mode="base",
+ **kwargs):
+ super(LayoutLMv2ForRe, self).__init__(
+ LayoutLMv2Model, LayoutLMv2ForRelationExtraction, mode, "re",
+ pretrained, checkpoints)
def forward(self, x):
x = self.model(
@@ -158,18 +202,27 @@ class LayoutLMv2ForRe(NLPBaseModel):
class LayoutXLMForRe(NLPBaseModel):
- def __init__(self, pretrained=True, checkpoints=None, **kwargs):
- super(LayoutXLMForRe, self).__init__(LayoutXLMModel,
- LayoutXLMForRelationExtraction,
- 're', pretrained, checkpoints)
+ def __init__(self, pretrained=True, checkpoints=None, mode="base",
+ **kwargs):
+ super(LayoutXLMForRe, self).__init__(
+ LayoutXLMModel, LayoutXLMForRelationExtraction, mode, "re",
+ pretrained, checkpoints)
+ self.use_visual_backbone = True
+ if hasattr(self.model.layoutxlm, "use_visual_backbone"
+ ) and self.model.layoutxlm.use_visual_backbone is False:
+ self.use_visual_backbone = False
def forward(self, x):
+ if self.use_visual_backbone is True:
+ image = x[4]
+ else:
+ image = None
x = self.model(
input_ids=x[0],
bbox=x[1],
attention_mask=x[2],
token_type_ids=x[3],
- image=x[4],
+ image=image,
position_ids=None,
head_mask=None,
labels=None,
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index b3c11d84b431dae9f74f3c2f444e475c0ff49d39..7c0c7fd003a38966a24fd116d8cfd3805aed6797 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -31,8 +31,8 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
SPINLabelDecode, VLLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
-from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
-from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
+from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
+from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, DistillationRePostProcess
from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
@@ -45,7 +45,9 @@ def build_post_process(config, global_config=None):
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
- 'TableMasterLabelDecode', 'SPINLabelDecode', 'VLLabelDecode'
+ 'TableMasterLabelDecode', 'SPINLabelDecode',
+ 'DistillationSerPostProcess', 'DistillationRePostProcess',
+ 'VLLabelDecode'
]
if config['name'] == 'PSEPostProcess':
diff --git a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
index 1d55d13d76b496ba0a5b540ba915889ce9146a8e..96c25d9aac01066f7a3841fe61aa7b0fe05041bd 100644
--- a/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
+++ b/ppocr/postprocess/vqa_token_re_layoutlm_postprocess.py
@@ -49,3 +49,25 @@ class VQAReTokenLayoutLMPostProcess(object):
result.append((ocr_info_head, ocr_info_tail))
results.append(result)
return results
+
+
+class DistillationRePostProcess(VQAReTokenLayoutLMPostProcess):
+ """
+ DistillationRePostProcess
+ """
+
+ def __init__(self, model_name=["Student"], key=None, **kwargs):
+ super().__init__(**kwargs)
+ if not isinstance(model_name, list):
+ model_name = [model_name]
+ self.model_name = model_name
+ self.key = key
+
+ def __call__(self, preds, *args, **kwargs):
+ output = dict()
+ for name in self.model_name:
+ pred = preds[name]
+ if self.key is not None:
+ pred = pred[self.key]
+ output[name] = super().__call__(pred, *args, **kwargs)
+ return output
diff --git a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
index 8a6669f71f5ae6a7a16931e565b43355de5928d9..5541da90a05d0137628f45f72b15fd61eba1e203 100644
--- a/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
+++ b/ppocr/postprocess/vqa_token_ser_layoutlm_postprocess.py
@@ -93,3 +93,25 @@ class VQASerTokenLayoutLMPostProcess(object):
ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
results.append(ocr_info)
return results
+
+
+class DistillationSerPostProcess(VQASerTokenLayoutLMPostProcess):
+ """
+ DistillationSerPostProcess
+ """
+
+ def __init__(self, class_path, model_name=["Student"], key=None, **kwargs):
+ super().__init__(class_path, **kwargs)
+ if not isinstance(model_name, list):
+ model_name = [model_name]
+ self.model_name = model_name
+ self.key = key
+
+ def __call__(self, preds, batch=None, *args, **kwargs):
+ output = dict()
+ for name in self.model_name:
+ pred = preds[name]
+ if self.key is not None:
+ pred = pred[self.key]
+ output[name] = super().__call__(pred, batch=batch, *args, **kwargs)
+ return output
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 3647111fddaa848a75873ab689559c63dd6d4814..e77a6ce0183611569193e1996e935f4bd30400a0 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -53,8 +53,12 @@ def load_model(config, model, optimizer=None, model_type='det'):
checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model')
best_model_dict = {}
+ is_float16 = False
if model_type == 'vqa':
+ # NOTE: for vqa model, resume training is not supported now
+ if config["Architecture"]["algorithm"] in ["Distillation"]:
+ return best_model_dict
checkpoints = config['Architecture']['Backbone']['checkpoints']
# load vqa method metric
if checkpoints:
@@ -78,6 +82,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
logger.warning(
"{}.pdopt is not exists, params of optimizer is not loaded".
format(checkpoints))
+
return best_model_dict
if checkpoints:
@@ -96,6 +101,9 @@ def load_model(config, model, optimizer=None, model_type='det'):
key, params.keys()))
continue
pre_value = params[key]
+ if pre_value.dtype == paddle.float16:
+ pre_value = pre_value.astype(paddle.float32)
+ is_float16 = True
if list(value.shape) == list(pre_value.shape):
new_state_dict[key] = pre_value
else:
@@ -103,7 +111,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
"The shape of model params {} {} not matched with loaded params shape {} !".
format(key, value.shape, pre_value.shape))
model.set_state_dict(new_state_dict)
-
+ if is_float16:
+ logger.info(
+ "The parameter type is float16, which is converted to float32 when loading"
+ )
if optimizer is not None:
if os.path.exists(checkpoints + '.pdopt'):
optim_dict = paddle.load(checkpoints + '.pdopt')
@@ -122,9 +133,10 @@ def load_model(config, model, optimizer=None, model_type='det'):
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
- load_pretrained_params(model, pretrained_model)
+ is_float16 = load_pretrained_params(model, pretrained_model)
else:
logger.info('train from scratch')
+ best_model_dict['is_float16'] = is_float16
return best_model_dict
@@ -138,19 +150,28 @@ def load_pretrained_params(model, path):
params = paddle.load(path + '.pdparams')
state_dict = model.state_dict()
new_state_dict = {}
+ is_float16 = False
for k1 in params.keys():
if k1 not in state_dict.keys():
logger.warning("The pretrained params {} not in model".format(k1))
else:
+ if params[k1].dtype == paddle.float16:
+ params[k1] = params[k1].astype(paddle.float32)
+ is_float16 = True
if list(state_dict[k1].shape) == list(params[k1].shape):
new_state_dict[k1] = params[k1]
else:
logger.warning(
"The shape of model params {} {} not matched with loaded params {} {} !".
format(k1, state_dict[k1].shape, k1, params[k1].shape))
+
model.set_state_dict(new_state_dict)
+ if is_float16:
+ logger.info(
+ "The parameter type is float16, which is converted to float32 when loading"
+ )
logger.info("load pretrain successful from {}".format(path))
- return model
+ return is_float16
def save_model(model,
@@ -166,15 +187,19 @@ def save_model(model,
"""
_mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix)
- paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
+ if config['Architecture']["model_type"] != 'vqa':
+ paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
if config['Architecture']["model_type"] != 'vqa':
paddle.save(model.state_dict(), model_prefix + '.pdparams')
metric_prefix = model_prefix
- else:
+ else: # for vqa system, we follow the save/load rules in NLP
if config['Global']['distributed']:
- model._layers.backbone.model.save_pretrained(model_prefix)
+ arch = model._layers
else:
- model.backbone.model.save_pretrained(model_prefix)
+ arch = model
+ if config["Architecture"]["algorithm"] in ["Distillation"]:
+ arch = arch.Student
+ arch.backbone.model.save_pretrained(model_prefix)
metric_prefix = os.path.join(model_prefix, 'metric')
# save metric and config
with open(metric_prefix + '.states', 'wb') as f:
diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md
index 05635265b5e5eff18429e2d595fc4195381299f5..28b794383bceccf655bdf00df5ee0c98841e2e95 100644
--- a/ppstructure/vqa/README.md
+++ b/ppstructure/vqa/README.md
@@ -216,7 +216,7 @@ Use the following command to complete the tandem prediction of `OCR + SER` based
```shell
cd ppstructure
-CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
+CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --vis_font_path=../doc/fonts/simfang.ttf --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
```
After the prediction is successful, the visualization images and results will be saved in the directory specified by the `output` field
diff --git a/ppstructure/vqa/README_ch.md b/ppstructure/vqa/README_ch.md
index b421a82d3a1cbe39f5c740bea486ec26593ab20f..f168110ed9b2e750b3b2ee6f5ab0116daebc3e77 100644
--- a/ppstructure/vqa/README_ch.md
+++ b/ppstructure/vqa/README_ch.md
@@ -215,7 +215,7 @@ python3.7 tools/export_model.py -c configs/vqa/ser/layoutxlm.yml -o Architecture
```shell
cd ppstructure
-CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
+CUDA_VISIBLE_DEVICES=0 python3.7 vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_model_dir=../output/ser/infer --ser_dict_path=../train_data/XFUND/class_list_xfun.txt --vis_font_path=../doc/fonts/simfang.ttf --image_dir=docs/vqa/input/zh_val_42.jpg --output=output
```
预测成功后,可视化图片和结果会保存在`output`字段指定的目录下
diff --git a/ppstructure/vqa/predict_vqa_token_ser.py b/ppstructure/vqa/predict_vqa_token_ser.py
index de0bbfe72d80d9a16de8b09657a98dc5285bb348..3097ebcf1640eb1e4dd65f76635f21231984b0ef 100644
--- a/ppstructure/vqa/predict_vqa_token_ser.py
+++ b/ppstructure/vqa/predict_vqa_token_ser.py
@@ -153,7 +153,7 @@ def main(args):
img_res = draw_ser_results(
image_file,
ser_res,
- font_path="../doc/fonts/simfang.ttf", )
+ font_path=args.vis_font_path, )
img_save_path = os.path.join(args.output,
os.path.basename(image_file))
diff --git a/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml b/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml
index a1497ba8fa4790a53cd602829edf3240ff8dc51a..3eb82d42bc3f2b3ca7420d999865977bbad09e31 100644
--- a/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml
+++ b/test_tipc/configs/ch_PP-OCRv2_rec/ch_PP-OCRv2_rec_distillation.yml
@@ -114,7 +114,7 @@ Train:
name: SimpleDataSet
data_dir: ./train_data/ic15_data/
label_file_list:
- - ./train_data/ic15_data/rec_gt_train4w.txt
+ - ./train_data/ic15_data/rec_gt_train.txt
transforms:
- DecodeImage:
img_mode: BGR
diff --git a/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml b/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml
index ee884f668767ea1c96782072c729bbcc700674d1..4c8ba0a6fa4a355e9bad1665a8de82399f919740 100644
--- a/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml
+++ b/test_tipc/configs/ch_PP-OCRv3_rec/ch_PP-OCRv3_rec_distillation.yml
@@ -153,7 +153,7 @@ Train:
data_dir: ./train_data/ic15_data/
ext_op_transform_idx: 1
label_file_list:
- - ./train_data/ic15_data/rec_gt_train4w.txt
+ - ./train_data/ic15_data/rec_gt_train.txt
transforms:
- DecodeImage:
img_mode: BGR
diff --git a/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt b/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt
index 420c6592d71653377c740c703bedeb8e048cfc03..59fc1bd4160ec77edb0b781c8ffa9845c6a3d5c7 100644
--- a/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_PP-OCRv3_rec/train_infer_python.txt
@@ -52,8 +52,9 @@ null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,48,320]}]
===========================train_benchmark_params==========================
-batch_size:128
+batch_size:64
fp_items:fp32|fp16
epoch:1
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
+
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 91%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index b42ab9db362b0ba56d795096fdc58a645b425480..73f1d498550b2672078a9353d665fe825a992cec 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_mobile_v2.0
+model_name:ch_ppocr_mobile_v2_0
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
similarity index 94%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
index becad991eab2535b2df7862d0d25707ef37f08f8..00373b61eef1e3aecf6f55d945b528ceb1d83a8b 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================ch_ppocr_mobile_v2.0===========================
-model_name:ch_ppocr_mobile_v2.0
+model_name:ch_ppocr_mobile_v2_0
python:python3.7
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
infer_export:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
index 17c2fbbae2e182c4a7631cb18908180d8c019b4f..3e01ae57380f99fd2eb637b08cb00ab81fd2e966 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -1,5 +1,5 @@
===========================paddle2onnx_params===========================
-model_name:ch_ppocr_mobile_v2.0
+model_name:ch_ppocr_mobile_v2_0
python:python3.7
2onnx: paddle2onnx
--det_model_dir:./inference/ch_ppocr_mobile_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
index d18e9f11fdd2ff605cdd8f6c1bcf51ca780eb766..305882aa30f0009bb109e74d1668ded8aa00ccfb 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0
+model_name:ch_ppocr_mobile_v2_0
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index 842c9340176d696c1e43e59491bdcab817f9256e..0c366b03ded2e3a7e038d21f2b482a76131e6a21 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0
+model_name:ch_ppocr_mobile_v2_0
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 88%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index 1d1c2ae283b2103c2e7282186ab1f53bec05cda3..ded332e674e76ef10a2b312783f2b1bbbbf963a6 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_mobile_v2.0_det
+model_name:ch_ppocr_mobile_v2_0_det
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_python_jetson.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_infer_python_jetson.txt
similarity index 92%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_python_jetson.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_infer_python_jetson.txt
index 24bb8746ab7793dbcb4af99102a007aca8b8e16b..5f9dfa5f55e4556036f45fdef4fa8e6edf8b3eb9 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_infer_python_jetson.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_infer_python_jetson.txt
@@ -1,5 +1,5 @@
===========================infer_params===========================
-model_name:ch_ppocr_mobile_v2.0_det
+model_name:ch_ppocr_mobile_v2_0_det
python:python
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer
infer_export:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_lite_cpp_arm_cpu.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_lite_cpp_arm_gpu_opencl.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
similarity index 84%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
index 00473d1062615834a42e350a727f50233efd831f..8f36ad4b86fc200d558e905e89e7d3594b2f13f1 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -1,5 +1,5 @@
===========================paddle2onnx_params===========================
-model_name:ch_ppocr_mobile_v2.0_det
+model_name:ch_ppocr_mobile_v2_0_det
python:python3.7
2onnx: paddle2onnx
--det_model_dir:./inference/ch_ppocr_mobile_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index c9dd5ad920d58f60ce36a7b489073279f23ba1b7..6dfd7e7bd02e587dd2dc895d1e206b10c17fe82f 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_det
+model_name:ch_ppocr_mobile_v2_0_det
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt
similarity index 98%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt
index 3db816cc0887eb2efc195965498174e868bfc6ec..f3aa9d0f8218a24b11e3d0d079ae79a07d3e5874 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_det
+model_name:ch_ppocr_mobile_v2_0_det
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_dcu_normal_normal_infer_python_dcu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_dcu_normal_normal_infer_python_dcu.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_dcu_normal_normal_infer_python_dcu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_dcu_normal_normal_infer_python_dcu.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
index 5271f78bb778f9e419da7f9bbbb6b4a6fafb305b..bf81d0baa8fa8c8ae590e5bbf33564fdf664b9c5 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_det
+model_name:ch_ppocr_mobile_v2_0_det
python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
index 6b3352f741a56124eead2f71c03c783e5c81a70d..df71e907022a8f4a262eb218879578b99464327a 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_det
+model_name:ch_ppocr_mobile_v2_0_det
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_mac_cpu_normal_normal_infer_python_mac_cpu.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_pact_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_pact_infer_python.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_pact_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_pact_infer_python.txt
index 04c8d0e194b687f58da1c449a6a0d8d9c1acd25e..ba880d1f9e1a1ade250c25581c8b238ec58c5e30 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_pact_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_pact_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_PACT
+model_name:ch_ppocr_mobile_v2_0_det_PACT
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_ptq_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_ptq_infer_python.txt
similarity index 93%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_ptq_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_ptq_infer_python.txt
index 2bdec848833b6cf3799370b0337fa00f185a94d5..45c4fd1ae832ab86c94bc04f9acf89ef8e95d09e 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_ptq_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_ptq_infer_python.txt
@@ -1,5 +1,5 @@
===========================kl_quant_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_KL
+model_name:ch_ppocr_mobile_v2_0_det_KL
python:python3.7
Global.pretrained_model:null
Global.save_inference_dir:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det/train_windows_gpu_normal_normal_infer_python_windows_cpu_gpu.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_FPGM/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_FPGM/train_infer_python.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_FPGM/train_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_FPGM/train_infer_python.txt
index dae3f8053a0264611b5baca0f45839f3550fe6a4..0f6df1ac52c8b42d21ae701634337113e2efab95 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_FPGM/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_FPGM/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_FPGM
+model_name:ch_ppocr_mobile_v2_0_det_FPGM
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
index 150a8a0315b83e8c62765a4aa66429cfd0590928..2014c6dbcd0f1c4e97348452b520c0e640245fe5 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_FPGM
+model_name:ch_ppocr_mobile_v2_0_det_FPGM
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 88%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index eb2fd0a001ab506f241bed7eac75d96cf4b5d5cb..f0e58dd566ac62a598930d01d7604e37273b99fa 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_KL
+model_name:ch_ppocr_mobile_v2_0_det_KL
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_mac_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_infer_python_mac_cpu.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_mac_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_infer_python_mac_cpu.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_windows_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_infer_python_windows_gpu_cpu.txt
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_infer_python_windows_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_infer_python_windows_gpu_cpu.txt
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
similarity index 95%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
index ab518de55ae6b157b26bf332ec3b0afcab71f97a..c5dc52583bb2ff293f6c8f2f49c55e5e1dffaefc 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_KL
+model_name:ch_ppocr_mobile_v2_0_det_KL
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index 049ec784581bddcce066bc049b66f6f0ceff9eed..82d4db32ab1c49c3d433e9efdc2d91df4c434cea 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_KL
+model_name:ch_ppocr_mobile_v2_0_det_KL
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index 17723f41ab762f5316cba59c08ec719aa54f03b1..5132330597ac5b9a120f1c1a4bec7bc8fb849f38 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_PACT
+model_name:ch_ppocr_mobile_v2_0_det_PACT
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_det_pact_infer
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
similarity index 95%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
index 229f70cf353318bf9ccc81f4e5be79dbc096de25..3be53952f26bd17a6f4262b64f0a5958a27a1fab 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_PACT
+model_name:ch_ppocr_mobile_v2_0_det_PACT
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_pact_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 95%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index 909d738919bed78d6db04e238818cd4fbbb75e5f..63e7f8f735afc99a60f211d9e0f2ff67c4a2a89e 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_det_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_PACT
+model_name:ch_ppocr_mobile_v2_0_det_PACT
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_pact_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 89%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index 480fb16cddfc4c2f4784cc8fa88512f063f7b2ae..332e632bd5da77398d0437c187e3abe11f36dc46 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec
+model_name:ch_ppocr_mobile_v2_0_rec
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_rec_infer/
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
index 5bab0c9e4c77edba302f6b536306816b09df9224..78b76edae1aaa353f032d2ca1dd2eb21e22183a3 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -1,5 +1,5 @@
===========================paddle2onnx_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec
+model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7
2onnx: paddle2onnx
--det_model_dir:
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index c0c5291cc480f9f34aa5dcded3eafce7feac89e3..5c60903f678778c4d10c88745a4cca12fb488c6f 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec
+model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_infer_python.txt
similarity index 98%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_infer_python.txt
index 36fdb1b91eceede0d692ff4c2680d1403ec86024..40f397948936beba0a3a4bdce9aa4a9953ec9d0f 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec
+model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
index 631118c0a9ab98c10129f12ec1c1cf2bbac46115..2f919d102b2abdd5b72642abc413a47b4cf17350 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec
+model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
index bd9c4a8df2565af73b6db24636b6dd132dac0cc2..f60e2790e5879d99e322f36489aba67cb0d2b66c 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec
+model_name:ch_ppocr_mobile_v2_0_rec
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_pact_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_pact_infer_python.txt
similarity index 90%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_pact_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_pact_infer_python.txt
index 77472fbdfb21c81bb713df175a135ccf9e652f25..9c1223f41a1edac8c71e4069be132c51e8db8e3c 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_pact_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_pact_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_PACT
+model_name:ch_ppocr_mobile_v2_0_rec_PACT
python:python3.7
gpu_list:0
Global.use_gpu:True|True
@@ -14,7 +14,7 @@ null:null
##
trainer:pact_train
norm_train:null
-pact_train:deploy/slim/quantization/quant.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
+pact_train:deploy/slim/quantization/quant.py -c test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
fpgm_train:null
distill_train:null
null:null
@@ -28,7 +28,7 @@ null:null
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
-quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
+quant_export:deploy/slim/quantization/export_model.py -c test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/rec_chinese_lite_train_v2.0.yml -o
fpgm_export:null
distill_export:null
export1:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_ptq_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_ptq_infer_python.txt
similarity index 84%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_ptq_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_ptq_infer_python.txt
index f63fe4c2bb6a17353ecb008d83e2bee9d38aec23..df47f328bd0a1cdca2afbae5afccc66455507ca8 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec/train_ptq_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec/train_ptq_infer_python.txt
@@ -1,10 +1,10 @@
===========================kl_quant_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_KL
+model_name:ch_ppocr_mobile_v2_0_rec_KL
python:python3.7
Global.pretrained_model:null
Global.save_inference_dir:null
infer_model:./inference/ch_ppocr_mobile_v2.0_rec_infer/
-infer_export:deploy/slim/quantization/quant_kl.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/rec_chinese_lite_train_v2.0.yml -o
+infer_export:deploy/slim/quantization/quant_kl.py -c test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/rec_chinese_lite_train_v2.0.yml -o
infer_quant:True
inference:tools/infer/predict_rec.py --rec_image_shape="3,32,320"
--use_gpu:False|True
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/rec_chinese_lite_train_v2.0.yml b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/rec_chinese_lite_train_v2.0.yml
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/rec_chinese_lite_train_v2.0.yml
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/rec_chinese_lite_train_v2.0.yml
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/train_infer_python.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/train_infer_python.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/train_infer_python.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/train_infer_python.txt
index 89daceeb5f4a991699a490b51358d33240e74913..94c9503103b7dbc52c8c7aa3e4d576e3d8eb1a0a 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_FPGM
+model_name:ch_ppocr_mobile_v2_0_rec_FPGM
python:python3.7
gpu_list:0
Global.use_gpu:True|True
@@ -15,7 +15,7 @@ null:null
trainer:fpgm_train
norm_train:null
pact_train:null
-fpgm_train:deploy/slim/prune/sensitivity_anal.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=./pretrain_models/ch_ppocr_mobile_v2.0_rec_train/best_accuracy
+fpgm_train:deploy/slim/prune/sensitivity_anal.py -c test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=./pretrain_models/ch_ppocr_mobile_v2.0_rec_train/best_accuracy
distill_train:null
null:null
null:null
@@ -29,7 +29,7 @@ Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
quant_export:null
-fpgm_export:deploy/slim/prune/export_prune_model.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o
+fpgm_export:deploy/slim/prune/export_prune_model.py -c test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o
distill_export:null
export1:null
export2:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
index 7abc3e9340fe49a2b0bf0efd5e3c370817cd4e9d..71555865ad900051459bd781520f4a936909f16a 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_FPGM
+model_name:ch_ppocr_mobile_v2_0_rec_FPGM
python:python3.7
gpu_list:0
Global.use_gpu:True|True
@@ -15,7 +15,7 @@ null:null
trainer:fpgm_train
norm_train:null
pact_train:null
-fpgm_train:deploy/slim/prune/sensitivity_anal.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=./pretrain_models/ch_ppocr_mobile_v2.0_rec_train/best_accuracy
+fpgm_train:deploy/slim/prune/sensitivity_anal.py -c test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=./pretrain_models/ch_ppocr_mobile_v2.0_rec_train/best_accuracy
distill_train:null
null:null
null:null
@@ -29,7 +29,7 @@ Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:null
quant_export:null
-fpgm_export:deploy/slim/prune/export_prune_model.py -c test_tipc/configs/ch_ppocr_mobile_v2.0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o
+fpgm_export:deploy/slim/prune/export_prune_model.py -c test_tipc/configs/ch_ppocr_mobile_v2_0_rec_FPGM/rec_chinese_lite_train_v2.0.yml -o
distill_export:null
export1:null
export2:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 89%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index adf06257a772cfd16d4109497c6e6ef7c3f8af8b..ef4c93fcdecb2e1e3adc832089efa6de50c5484c 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_KL
+model_name:ch_ppocr_mobile_v2_0_rec_KL
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_rec_klquant_infer
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
similarity index 95%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
index d9de1cc19a729485845601fe929bf57d74002641..d904e22a7cc1f68ab8bccafa86ec11b92c67c244 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_KL
+model_name:ch_ppocr_mobile_v2_0_rec_KL
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_klquant_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index 948e3dceb3ef7e3f2199be5e417cfc5fc763d975..de4f7ed2c1d5cd457917cb3668f181e2366bb78f 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_KL
+model_name:ch_ppocr_mobile_v2_0_rec_KL
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/rec_chinese_lite_train_v2.0.yml b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/rec_chinese_lite_train_v2.0.yml
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_KL/rec_chinese_lite_train_v2.0.yml
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_KL/rec_chinese_lite_train_v2.0.yml
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 89%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index ba2df90f75d2c70e043c45ea19d681aabc2b6fb2..74ca7b50b8b2ac8efa00d3a1cfa5e60b97cf99ba 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_PACT
+model_name:ch_ppocr_mobile_v2_0_rec_PACT
use_opencv:True
infer_model:./inference/ch_ppocr_mobile_v2.0_rec_pact_infer
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
similarity index 95%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
index 1a49a10f9b9d4e32916dd35bae3380e2ca5bebb9..5a3047448b5264e9ab771249c7b113e32943c5f1 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_det_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_det_PACT
+model_name:ch_ppocr_mobile_v2_0_rec_PACT
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_mobile_v2.0_det_pact_infer/
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 95%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index f123f365432ab68f2484cc11dd9ef94c8a60ea8e..5871199bc020fb231d3f485d312339f634d5ccd3 100644
--- a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_mobile_v2.0_rec_PACT
+model_name:ch_ppocr_mobile_v2_0_rec_PACT
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:null
diff --git a/test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml b/test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/rec_chinese_lite_train_v2.0.yml
similarity index 100%
rename from test_tipc/configs/ch_ppocr_mobile_v2.0_rec_PACT/rec_chinese_lite_train_v2.0.yml
rename to test_tipc/configs/ch_ppocr_mobile_v2_0_rec_PACT/rec_chinese_lite_train_v2.0.yml
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 91%
rename from test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index 7c980b2baeef7161a93dea360089b333f2003a31..ba8646fd9d73c36a205c2e68ed35a2b777a47cfb 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_server_v2.0
+model_name:ch_ppocr_server_v2_0
use_opencv:True
infer_model:./inference/ch_ppocr_server_v2.0_det_infer/
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
similarity index 94%
rename from test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
index b20596f7a1db6da04307a7e527ef596477d237d3..53f8ab0e746df87205197179ca3367d32eb56a6d 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================ch_ppocr_server_v2.0===========================
-model_name:ch_ppocr_server_v2.0
+model_name:ch_ppocr_server_v2_0
python:python3.7
infer_model:./inference/ch_ppocr_server_v2.0_det_infer/
infer_export:null
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
similarity index 85%
rename from test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
index e478896a54957481a3ce4c485ac02cd7979233dc..9e2cf191f3c24523279ee53006a00b0be9a83297 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -1,5 +1,5 @@
===========================paddle2onnx_params===========================
-model_name:ch_ppocr_server_v2.0
+model_name:ch_ppocr_server_v2_0
python:python3.7
2onnx: paddle2onnx
--det_model_dir:./inference/ch_ppocr_server_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
index bbfec44dbab08dcfb932a922797448e541ea385b..55b27e04a20de07d8baf71046253601c658799ea 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_serving_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_server_v2.0
+model_name:ch_ppocr_server_v2_0
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_server_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 97%
rename from test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index 8853e709d40a0fba6bedd7ce582425e39b9076ed..21b8c9a082116c58e50c58e37a5a6842a1839b9b 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_server_v2.0
+model_name:ch_ppocr_server_v2_0
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_server_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml b/test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml
similarity index 100%
rename from test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml
rename to test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 88%
rename from test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index 69ae939e2b6cab5e07bc4e401a83c66324754223..4a30affd07748cd87f4c1982d8555f7b5a0a0b54 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_server_v2.0_det
+model_name:ch_ppocr_server_v2_0_det
use_opencv:True
infer_model:./inference/ch_ppocr_server_v2.0_det_infer/
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
similarity index 82%
rename from test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
index c8bebf54f2ed2627cce9a22013d1566eb7a7b6ef..b7dd6e22b90b0cfaeeac75a7f090e2a199c94831 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -1,5 +1,5 @@
===========================paddle2onnx_params===========================
-model_name:ch_ppocr_server_v2.0_det
+model_name:ch_ppocr_server_v2_0_det
python:python3.7
2onnx: paddle2onnx
--det_model_dir:./inference/ch_ppocr_server_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index 018dd1a227064479ebd60570113b122b035e7704..4d4f0679bf6cb77b6bdb4fa7494677e933554911 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_det/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_server_v2.0_det
+model_name:ch_ppocr_server_v2_0_det
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:./inference/ch_ppocr_server_v2.0_det_infer/
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt b/test_tipc/configs/ch_ppocr_server_v2_0_det/train_infer_python.txt
similarity index 92%
rename from test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_det/train_infer_python.txt
index 7b90a4078a0c30f9d5ecab60c82acbd4052821ea..90ed29f4303994995ff604d321d9643dcc8c46c1 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_det/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_server_v2.0_det
+model_name:ch_ppocr_server_v2_0_det
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
quant_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
quant_export:null
fpgm_export:null
distill_export:null
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
similarity index 90%
rename from test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
index 12388d967755c54a46efdb915ef047896dddaef7..f398078fc5d4059f138385c583c075bc10c6ccc3 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_det/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_server_v2.0_det
+model_name:ch_ppocr_server_v2_0_det
python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
quant_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
quant_export:null
fpgm_export:null
distill_export:null
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
similarity index 90%
rename from test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
index 93ed14cb600229e744167f26573cba406880db8e..7a2d0a53c18a797318ed4ebbae9065ed2221b439 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_det/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_server_v2.0_det
+model_name:ch_ppocr_server_v2_0_det
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
quant_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_det/det_r50_vd_db.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_det/det_r50_vd_db.yml -o
quant_export:null
fpgm_export:null
distill_export:null
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
similarity index 89%
rename from test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
index cbec272cce544e332fd908d4946321a15543fcae..3f3905516abcd6c7bb517bc0df50af331335dabe 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_infer_cpp_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================cpp_infer_params===========================
-model_name:ch_ppocr_server_v2.0_rec
+model_name:ch_ppocr_server_v2_0_rec
use_opencv:True
infer_model:./inference/ch_ppocr_server_v2.0_rec_infer/
infer_quant:False
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
index 462f6090d987ac2c58656136e896e71bcdc3bee1..89b9661003d1a2ee3d0987f50a9278fc96391988 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt
@@ -1,5 +1,5 @@
===========================paddle2onnx_params===========================
-model_name:ch_ppocr_server_v2.0_rec
+model_name:ch_ppocr_server_v2_0_rec
python:python3.7
2onnx: paddle2onnx
--det_model_dir:
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
similarity index 96%
rename from test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
index 7f456320b687549fbcd6d4f0be7a1b4a2969684a..4133e961cdaf4acb6dbf7421f38147e996e66401 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_rec/model_linux_gpu_normal_normal_serving_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================serving_params===========================
-model_name:ch_ppocr_server_v2.0_rec
+model_name:ch_ppocr_server_v2_0_rec
python:python3.7
trans_model:-m paddle_serving_client.convert
--det_dirname:null
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml b/test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml
similarity index 100%
rename from test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml
rename to test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt b/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_infer_python.txt
similarity index 88%
rename from test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_rec/train_infer_python.txt
index 9fc117d67c6c2c048b2c8797bc07be8c93b0d519..b9a1ae4984c30a08d75b73b884ceb97658eb11c7 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_infer_python.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_server_v2.0_rec
+model_name:ch_ppocr_server_v2_0_rec
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/ch_ppocr_server_v2.0_rec_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py
--use_gpu:True|False
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
index 9884ab247b80de4ca700bf084cea4faa89c86396..d5f57aef9f9834fb1e7d620ee9e23cedbc83a566 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_server_v2.0_rec
+model_name:ch_ppocr_server_v2_0_rec
python:python3.7
gpu_list:192.168.0.1,192.168.0.2;0,1
Global.use_gpu:True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/ch_ppocr_server_v2.0_rec_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py
--use_gpu:False
diff --git a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
similarity index 87%
rename from test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
rename to test_tipc/configs/ch_ppocr_server_v2_0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
index 63ddaa4a8b2dcb19823034ee85af14b248b109b2..20eb10b8e6a0559119a7f810c3e3aed4458e696f 100644
--- a/test_tipc/configs/ch_ppocr_server_v2.0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
+++ b/test_tipc/configs/ch_ppocr_server_v2_0_rec/train_linux_gpu_normal_amp_infer_python_linux_gpu_cpu.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:ch_ppocr_server_v2.0_rec
+model_name:ch_ppocr_server_v2_0_rec
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/ch_ppocr_server_v2.0_rec_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2.0_rec/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/ch_ppocr_server_v2_0_rec/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py
--use_gpu:True|False
diff --git a/test_tipc/configs/det_mv3_east_v2.0/det_mv3_east.yml b/test_tipc/configs/det_mv3_east_v2_0/det_mv3_east.yml
similarity index 100%
rename from test_tipc/configs/det_mv3_east_v2.0/det_mv3_east.yml
rename to test_tipc/configs/det_mv3_east_v2_0/det_mv3_east.yml
diff --git a/test_tipc/configs/det_mv3_east_v2.0/train_infer_python.txt b/test_tipc/configs/det_mv3_east_v2_0/train_infer_python.txt
similarity index 91%
rename from test_tipc/configs/det_mv3_east_v2.0/train_infer_python.txt
rename to test_tipc/configs/det_mv3_east_v2_0/train_infer_python.txt
index 1ec1597a4d50ba1c41cfb076fa7431f170e183bf..9c6d9660d545276bf7f1f2d650dd90354709251b 100644
--- a/test_tipc/configs/det_mv3_east_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_mv3_east_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:det_mv3_east_v2.0
+model_name:det_mv3_east_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/det_mv3_east_v2.0/det_mv3_east.yml -o
+norm_train:tools/train.py -c test_tipc/configs/det_mv3_east_v2_0/det_mv3_east.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/det_mv3_east_v2.0/det_mv3_east.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/det_mv3_east_v2_0/det_mv3_east.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/det_mv3_east_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/det_mv3_east_v2.0/det_mv3_east.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/det_mv3_east_v2_0/det_mv3_east.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
diff --git a/test_tipc/configs/det_mv3_pse_v2.0/det_mv3_pse.yml b/test_tipc/configs/det_mv3_pse_v2_0/det_mv3_pse.yml
similarity index 100%
rename from test_tipc/configs/det_mv3_pse_v2.0/det_mv3_pse.yml
rename to test_tipc/configs/det_mv3_pse_v2_0/det_mv3_pse.yml
diff --git a/test_tipc/configs/det_mv3_pse_v2.0/train_infer_python.txt b/test_tipc/configs/det_mv3_pse_v2_0/train_infer_python.txt
similarity index 91%
rename from test_tipc/configs/det_mv3_pse_v2.0/train_infer_python.txt
rename to test_tipc/configs/det_mv3_pse_v2_0/train_infer_python.txt
index daeec69f84a766e1d6cd2f8906772c27f5f8d048..525fdc7d4bc1f037e8b133df39d0e86c173de95a 100644
--- a/test_tipc/configs/det_mv3_pse_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_mv3_pse_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:det_mv3_pse_v2.0
+model_name:det_mv3_pse_v2_0
python:python3.7
gpu_list:0
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/det_mv3_pse_v2.0/det_mv3_pse.yml -o
+norm_train:tools/train.py -c test_tipc/configs/det_mv3_pse_v2_0/det_mv3_pse.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/det_mv3_pse_v2.0/det_mv3_pse.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/det_mv3_pse_v2_0/det_mv3_pse.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/det_mv3_pse_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/det_mv3_pse_v2.0/det_mv3_pse.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/det_mv3_pse_v2_0/det_mv3_pse.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
diff --git a/test_tipc/configs/det_r50_db_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_db_v2_0/train_infer_python.txt
similarity index 96%
rename from test_tipc/configs/det_r50_db_v2.0/train_infer_python.txt
rename to test_tipc/configs/det_r50_db_v2_0/train_infer_python.txt
index 11af0ad18e948d9fa1f325745988877125583658..1d0d9693a98524581bb17850bfdc81a2bc3e460c 100644
--- a/test_tipc/configs/det_r50_db_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_db_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:det_r50_db_v2.0
+model_name:det_r50_db_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
diff --git a/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml b/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml
similarity index 100%
rename from test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml
rename to test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml
diff --git a/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/train_infer_python.txt
similarity index 88%
rename from test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt
rename to test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/train_infer_python.txt
index 2d294fd3038f5506a28d637dbe1aba44b5da237b..92ded19d67c7f8111419897414a5212cb9b3614f 100644
--- a/test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:det_r50_dcn_fce_ctw_v2.0
+model_name:det_r50_dcn_fce_ctw_v2_0
python:python3.7
gpu_list:0
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
+norm_train:tools/train.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/det_r50_dcn_fce_ctw_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2.0/det_r50_vd_dcn_fce_ctw.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/det_r50_dcn_fce_ctw_v2_0/det_r50_vd_dcn_fce_ctw.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
diff --git a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml b/test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/det_r50_vd_sast_icdar2015.yml
similarity index 100%
rename from test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml
rename to test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/det_r50_vd_sast_icdar2015.yml
diff --git a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/train_infer_python.txt
similarity index 89%
rename from test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
rename to test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/train_infer_python.txt
index b70ef46b4afb3a39f3bbd3d6274f0135a0646a37..b01f1925b4bcfc7ddf4cae891378e0e10d021869 100644
--- a/test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:det_r50_vd_sast_icdar15_v2.0
+model_name:det_r50_vd_sast_icdar15_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml -o
+norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/det_r50_vd_sast_icdar2015.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/det_r50_vd_sast_icdar2015.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
inference_dir:null
train_model:./inference/det_r50_vd_sast_icdar15_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_icdar15_v2.0/det_r50_vd_sast_icdar2015.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_icdar15_v2_0/det_r50_vd_sast_icdar2015.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
diff --git a/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/det_r50_vd_sast_totaltext.yml b/test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/det_r50_vd_sast_totaltext.yml
similarity index 100%
rename from test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/det_r50_vd_sast_totaltext.yml
rename to test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/det_r50_vd_sast_totaltext.yml
diff --git a/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt b/test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/train_infer_python.txt
similarity index 88%
rename from test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt
rename to test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/train_infer_python.txt
index 7be5af7ddee0ed0f688980f5d5dca5a99c9705a0..a47ad6803053242fa8f6e6c6063e3fd2625d97c8 100644
--- a/test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:det_r50_vd_sast_totaltext_v2.0
+model_name:det_r50_vd_sast_totaltext_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./pretrain_models/ResNet50_vd_ssld_pretrained
+norm_train:tools/train.py -c test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./pretrain_models/ResNet50_vd_ssld_pretrained
pact_train:null
fpgm_train:null
distill_train:null
@@ -27,7 +27,7 @@ null:null
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/det_r50_vd_sast_totaltext.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/det_r50_vd_sast_totaltext.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
inference_dir:null
train_model:./inference/det_r50_vd_sast_totaltext_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_totaltext_v2.0/det_r50_vd_sast_totaltext.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/det_r50_vd_sast_totaltext_v2_0/det_r50_vd_sast_totaltext.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
diff --git a/test_tipc/configs/layoutxlm_ser/train_infer_python.txt b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6d05d413e106eee873b026d60fb4320c61f833c4
--- /dev/null
+++ b/test_tipc/configs/layoutxlm_ser/train_infer_python.txt
@@ -0,0 +1,59 @@
+===========================train_params===========================
+model_name:layoutxlm_ser
+python:python3.7
+gpu_list:0|0,1
+Global.use_gpu:True|True
+Global.auto_cast:fp32
+Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
+Global.save_model_dir:./output/
+Train.loader.batch_size_per_card:lite_train_lite_infer=4|whole_train_whole_infer=8
+Architecture.Backbone.checkpoints:null
+train_model_name:latest
+train_infer_img_dir:ppstructure/docs/vqa/input/zh_val_42.jpg
+null:null
+##
+trainer:norm_train
+norm_train:tools/train.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o Global.print_batch_step=1 Global.eval_batch_step=[1000,1000] Train.loader.shuffle=false
+pact_train:null
+fpgm_train:null
+distill_train:null
+null:null
+null:null
+##
+===========================eval_params===========================
+eval:null
+null:null
+##
+===========================infer_params===========================
+Global.save_inference_dir:./output/
+Architecture.Backbone.checkpoints:
+norm_export:tools/export_model.py -c configs/vqa/ser/layoutxlm_xfund_zh.yml -o
+quant_export:
+fpgm_export:
+distill_export:null
+export1:null
+export2:null
+##
+infer_model:null
+infer_export:null
+infer_quant:False
+inference:ppstructure/vqa/predict_vqa_token_ser.py --vqa_algorithm=LayoutXLM --ser_dict_path=train_data/XFUND/class_list_xfun.txt --output=output
+--use_gpu:True|False
+--enable_mkldnn:False
+--cpu_threads:6
+--rec_batch_num:1
+--use_tensorrt:False
+--precision:fp32
+--ser_model_dir:
+--image_dir:./ppstructure/docs/vqa/input/zh_val_42.jpg
+null:null
+--benchmark:False
+null:null
+===========================infer_benchmark_params==========================
+random_infer_input:[{float32,[3,224,224]}]
+===========================train_benchmark_params==========================
+batch_size:4
+fp_items:fp32|fp16
+epoch:3
+--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
+flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98
diff --git a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/rec_icdar15_train.yml
similarity index 100%
rename from test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml
rename to test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/rec_icdar15_train.yml
diff --git a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/train_infer_python.txt
similarity index 89%
rename from test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/train_infer_python.txt
index 4e34a6a525fb8104407d04c617db39934b84e140..db89b4c78d72d1853096d6b44b73a7ca61792dfe 100644
--- a/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_mv3_none_bilstm_ctc_v2.0
+model_name:rec_mv3_none_bilstm_ctc_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
+norm_train:tools/train.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o Global.print_batch_step=4 Train.loader.shuffle=false
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_mv3_none_bilstm_ctc_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
--use_gpu:True|False
diff --git a/test_tipc/configs/rec_mv3_none_none_ctc_v2.0/rec_icdar15_train.yml b/test_tipc/configs/rec_mv3_none_none_ctc_v2_0/rec_icdar15_train.yml
similarity index 100%
rename from test_tipc/configs/rec_mv3_none_none_ctc_v2.0/rec_icdar15_train.yml
rename to test_tipc/configs/rec_mv3_none_none_ctc_v2_0/rec_icdar15_train.yml
diff --git a/test_tipc/configs/rec_mv3_none_none_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_none_none_ctc_v2_0/train_infer_python.txt
similarity index 87%
rename from test_tipc/configs/rec_mv3_none_none_ctc_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_mv3_none_none_ctc_v2_0/train_infer_python.txt
index 593de3ff20aa9890e7d9a02a9e5ca5b130e5a266..003e91ff3d95e62d4353d7c4545e780ecd2f9708 100644
--- a/test_tipc/configs/rec_mv3_none_none_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_none_none_ctc_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_mv3_none_none_ctc_v2.0
+model_name:rec_mv3_none_none_ctc_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2_0/rec_icdar15_train.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2_0/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2_0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_mv3_none_none_ctc_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_none_none_ctc_v2_0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
--use_gpu:True|False
diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml b/test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/rec_mv3_tps_bilstm_att.yml
similarity index 100%
rename from test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml
rename to test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/rec_mv3_tps_bilstm_att.yml
diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/train_infer_python.txt
similarity index 88%
rename from test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/train_infer_python.txt
index 1b2d9abb0f00467ce92c4f51f97c283bc3e85c5e..c7b416c83323863a905929a2effcb1d3ad856422 100644
--- a/test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_mv3_tps_bilstm_att_v2.0
+model_name:rec_mv3_tps_bilstm_att_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/rec_mv3_tps_bilstm_att.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/rec_mv3_tps_bilstm_att.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/rec_mv3_tps_bilstm_att.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_mv3_tps_bilstm_att_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2.0/rec_mv3_tps_bilstm_att.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_att_v2_0/rec_mv3_tps_bilstm_att.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="RARE" --min_subgraph_size=5
--use_gpu:True|False
diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml b/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml
similarity index 100%
rename from test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml
rename to test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml
diff --git a/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/train_infer_python.txt
similarity index 89%
rename from test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/train_infer_python.txt
index 1367c7abd4c9ca5b0c6f1eb291dd2af8d9fa4de4..0c6e2d1da7f163521e8859bd8c96436b2a6bac64 100644
--- a/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_mv3_tps_bilstm_ctc_v2.0
+model_name:rec_mv3_tps_bilstm_ctc_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_mv3_tps_bilstm_ctc_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_mv3_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="StarNet"
--use_gpu:True|False
diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml
index d0cb20481f56a093f96c3d13f5fa2c2d13ae0c69..21d56b685c7da7b1db43acb6570bf7f40d0426fa 100644
--- a/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml
+++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/rec_r32_gaspin_bilstm_att.yml
@@ -8,7 +8,7 @@ Global:
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [0, 2000]
cal_metric_during_train: True
- pretrained_model:
+ pretrained_model: pretrain_models/rec_r32_gaspin_bilstm_att_train/best_accuracy
checkpoints:
save_inference_dir:
use_visualdl: False
diff --git a/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt
index 4915055a576f0a5c1f7b0935a31d1d3c266903a5..115dfd661abc64db9e14c629f79099be7b6ff0e0 100644
--- a/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt
+++ b/test_tipc/configs/rec_r32_gaspin_bilstm_att/train_infer_python.txt
@@ -1,6 +1,6 @@
===========================train_params===========================
model_name:rec_r32_gaspin_bilstm_att
-python:python
+python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
@@ -39,11 +39,11 @@ infer_export:tools/export_model.py -c test_tipc/configs/rec_r32_gaspin_bilstm_at
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/dict/spin_dict.txt --use_space_char=False --rec_image_shape="3,32,100" --rec_algorithm="SPIN"
--use_gpu:True|False
---enable_mkldnn:True|False
---cpu_threads:1|6
+--enable_mkldnn:False
+--cpu_threads:6
--rec_batch_num:1|6
---use_tensorrt:False|False
---precision:fp32|int8
+--use_tensorrt:False
+--precision:fp32
--rec_model_dir:
--image_dir:./inference/rec_inference
--save_log_path:./test/output/
diff --git a/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/rec_icdar15_train.yml b/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/rec_icdar15_train.yml
similarity index 100%
rename from test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/rec_icdar15_train.yml
rename to test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/rec_icdar15_train.yml
diff --git a/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/train_infer_python.txt
similarity index 86%
rename from test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/train_infer_python.txt
index 46aa3d719051a4f124583f88709026569d95c1c7..07a6190b0ef09da5cd20b9dd8ea922544c578710 100644
--- a/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_r34_vd_none_bilstm_ctc_v2.0
+model_name:rec_r34_vd_none_bilstm_ctc_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_r34_vd_none_bilstm_ctc_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
--use_gpu:True|False
diff --git a/test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/rec_icdar15_train.yml b/test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/rec_icdar15_train.yml
similarity index 100%
rename from test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/rec_icdar15_train.yml
rename to test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/rec_icdar15_train.yml
diff --git a/test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/train_infer_python.txt
similarity index 86%
rename from test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/train_infer_python.txt
index 3e066d7b72a6a707322b3aabe41ca6d698496433..145793aa472d8330daf9321f44692a03e7ef6354 100644
--- a/test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_r34_vd_none_none_ctc_v2.0
+model_name:rec_r34_vd_none_none_ctc_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/rec_icdar15_train.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_r34_vd_none_none_ctc_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2.0/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_none_none_ctc_v2_0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100"
--use_gpu:True|False
diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml b/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/rec_r34_vd_tps_bilstm_att.yml
similarity index 100%
rename from test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml
rename to test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/rec_r34_vd_tps_bilstm_att.yml
diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/train_infer_python.txt
similarity index 87%
rename from test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/train_infer_python.txt
index 1e4f46633efbf36fc78ed2beb7ed883d1483b3b0..759518a4a11a17e076401bb8dd193617c9f10530 100644
--- a/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_r34_vd_tps_bilstm_att_v2.0
+model_name:rec_r34_vd_tps_bilstm_att_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/rec_r34_vd_tps_bilstm_att.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/rec_r34_vd_tps_bilstm_att.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/rec_r34_vd_tps_bilstm_att.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_r34_vd_tps_bilstm_att_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2.0/rec_r34_vd_tps_bilstm_att.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_att_v2_0/rec_r34_vd_tps_bilstm_att.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="RARE" --min_subgraph_size=5
--use_gpu:True|False
diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml b/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml
similarity index 100%
rename from test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml
rename to test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml
diff --git a/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt b/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/train_infer_python.txt
similarity index 88%
rename from test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt
rename to test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/train_infer_python.txt
index 9e795b66453039696ed5eedb92fba5e25150413c..ecc898341ce14dfed0de4290b798dd70078ae2da 100644
--- a/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/train_infer_python.txt
+++ b/test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/train_infer_python.txt
@@ -1,5 +1,5 @@
===========================train_params===========================
-model_name:rec_r34_vd_tps_bilstm_ctc_v2.0
+model_name:rec_r34_vd_tps_bilstm_ctc_v2_0
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
@@ -13,7 +13,7 @@ train_infer_img_dir:./inference/rec_inference
null:null
##
trainer:norm_train
-norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_train:tools/train.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
pact_train:null
fpgm_train:null
distill_train:null
@@ -21,13 +21,13 @@ null:null
null:null
##
===========================eval_params===========================
-eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+eval:tools/eval.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
-norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+norm_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
quant_export:null
fpgm_export:null
distill_export:null
@@ -35,7 +35,7 @@ export1:null
export2:null
##
train_model:./inference/rec_r34_vd_tps_bilstm_ctc_v2.0_train/best_accuracy
-infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2.0/rec_icdar15_train.yml -o
+infer_export:tools/export_model.py -c test_tipc/configs/rec_r34_vd_tps_bilstm_ctc_v2_0/rec_icdar15_train.yml -o
infer_quant:False
inference:tools/infer/predict_rec.py --rec_char_dict_path=./ppocr/utils/ic15_dict.txt --rec_image_shape="3,32,100" --rec_algorithm="StarNet"
--use_gpu:True|False
diff --git a/test_tipc/docs/benchmark_train.md b/test_tipc/docs/benchmark_train.md
index a7f95eb6c530e1c451bb400cdb193694e2aee5f6..50cc13b92fc9b566c95d61cac00e587547a94811 100644
--- a/test_tipc/docs/benchmark_train.md
+++ b/test_tipc/docs/benchmark_train.md
@@ -69,7 +69,8 @@ train_log/
| det_r50_vd_east_v2.0 |[config](../configs/det_r50_vd_east_v2.0/train_infer_python.txt) | 42.485 | 42.624 / 42.663 / 42.561 |0.00239083 | 67.61 |67.825/ 68.299/ 68.51| 0.00999854 | 10,000| 2,000|
| det_r50_vd_pse_v2.0 |[config](../configs/det_r50_vd_pse_v2.0/train_infer_python.txt) | 16.455 | 16.517 / 16.555 / 16.353 |0.012201752 | 27.02 |27.288 / 27.152 / 27.408| 0.009340339 | 10,000| 2,000|
| rec_mv3_none_bilstm_ctc_v2.0 |[config](../configs/rec_mv3_none_bilstm_ctc_v2.0/train_infer_python.txt) | 2288.358 | 2291.906 / 2293.725 / 2290.05 |0.001602197 | 2336.17 |2327.042 / 2328.093 / 2344.915| 0.007622025 | 600,000| 160,000|
+| layoutxlm_ser |[config](../configs/layoutxlm/train_infer_python.txt) | 18.001 | 18.114 / 18.107 / 18.307 |0.010924783 | 21.982 | 21.507 / 21.116 / 21.406| 0.018180127 | 1490 | 1490|
| PP-Structure-table |[config](../configs/en_table_structure/train_infer_python.txt) | 14.151 | 14.077 / 14.23 / 14.25 |0.012140351 | 16.285 | 16.595 / 16.878 / 16.531 | 0.020559308 | 20,000| 5,000|
| det_r50_dcn_fce_ctw_v2.0 |[config](../configs/det_r50_dcn_fce_ctw_v2.0/train_infer_python.txt) | 14.057 | 14.029 / 14.02 / 14.014 |0.001069214 | 18.298 |18.411 / 18.376 / 18.331| 0.004345228 | 10,000| 2,000|
| ch_PP-OCRv3_det |[config](../configs/ch_PP-OCRv3_det/train_infer_python.txt) | 8.622 | 8.431 / 8.423 / 8.479|0.006604552 | 14.203 |14.346 14.468 14.23| 0.016450097 | 10,000| 2,000|
-| ch_PP-OCRv3_rec |[config](../configs/ch_PP-OCRv3_rec/train_infer_python.txt) | 73.627 | 72.46 / 73.575 / 73.704|0.016878324 | | | | 160,000| 40,000|
\ No newline at end of file
+| ch_PP-OCRv3_rec |[config](../configs/ch_PP-OCRv3_rec/train_infer_python.txt) | 90.239 | 90.077 / 91.513 / 91.325|0.01569176 | | | | 160,000| 40,000|
diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh
index cb3fa2440d9672ba113904bd1548d458491d1d8c..76543f39e4952b40368cdd392acc430dda8fcd9b 100644
--- a/test_tipc/prepare.sh
+++ b/test_tipc/prepare.sh
@@ -22,7 +22,7 @@ trainer_list=$(func_parser_value "${lines[14]}")
if [ ${MODE} = "benchmark_train" ];then
pip install -r requirements.txt
- if [[ ${model_name} =~ "ch_ppocr_mobile_v2.0_det" || ${model_name} =~ "det_mv3_db_v2_0" ]];then
+ if [[ ${model_name} =~ "ch_ppocr_mobile_v2_0_det" || ${model_name} =~ "det_mv3_db_v2_0" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
rm -rf ./train_data/icdar2015
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
@@ -30,7 +30,7 @@ if [ ${MODE} = "benchmark_train" ];then
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi
- if [[ ${model_name} =~ "ch_ppocr_server_v2.0_det" || ${model_name} =~ "ch_PP-OCRv3_det" ]];then
+ if [[ ${model_name} =~ "ch_ppocr_server_v2_0_det" || ${model_name} =~ "ch_PP-OCRv3_det" ]];then
rm -rf ./train_data/icdar2015
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_benckmark.tar
@@ -55,7 +55,7 @@ if [ ${MODE} = "benchmark_train" ];then
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi
- if [[ ${model_name} =~ "det_r50_db_v2.0" || ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then
+ if [[ ${model_name} =~ "det_r50_db_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
rm -rf ./train_data/icdar2015
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/icdar2015_benckmark.tar --no-check-certificate
@@ -71,13 +71,23 @@ if [ ${MODE} = "benchmark_train" ];then
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi
- if [[ ${model_name} =~ "ch_ppocr_mobile_v2.0_rec" || ${model_name} =~ "ch_ppocr_server_v2.0_rec" || ${model_name} =~ "ch_PP-OCRv2_rec" || ${model_name} =~ "rec_mv3_none_bilstm_ctc_v2.0" || ${model_name} =~ "ch_PP-OCRv3_rec" ]];then
- rm -rf ./train_data/ic15_data_benckmark
+ if [[ ${model_name} =~ "ch_ppocr_mobile_v2_0_rec" || ${model_name} =~ "ch_ppocr_server_v2_0_rec" || ${model_name} =~ "ch_PP-OCRv2_rec" || ${model_name} =~ "rec_mv3_none_bilstm_ctc_v2_0" || ${model_name} =~ "ch_PP-OCRv3_rec" ]];then
+ rm -rf ./train_data/ic15_data
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/ic15_data_benckmark.tar --no-check-certificate
cd ./train_data/ && tar xf ic15_data_benckmark.tar
ln -s ./ic15_data_benckmark ./ic15_data
cd ../
fi
+ if [[ ${model_name} =~ "ch_PP-OCRv2_rec" || ${model_name} =~ "ch_PP-OCRv3_rec" ]];then
+ rm -rf ./train_data/ic15_data
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/ic15_data_benckmark.tar --no-check-certificate
+ cd ./train_data/ && tar xf ic15_data_benckmark.tar
+ ln -s ./ic15_data_benckmark ./ic15_data
+ cd ic15_data
+ mv rec_gt_train4w.txt rec_gt_train.txt
+ cd ../
+ cd ../
+ fi
if [[ ${model_name} == "en_table_structure" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf en_ppocr_mobile_v2.0_table_structure_train.tar && cd ../
@@ -87,7 +97,7 @@ if [ ${MODE} = "benchmark_train" ];then
ln -s ./pubtabnet_benckmark ./pubtabnet
cd ../
fi
- if [[ ${model_name} == "det_r50_dcn_fce_ctw_v2.0" ]]; then
+ if [[ ${model_name} == "det_r50_dcn_fce_ctw_v2_0" ]]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar && cd ../
rm -rf ./train_data/icdar2015
@@ -96,6 +106,19 @@ if [ ${MODE} = "benchmark_train" ];then
ln -s ./icdar2015_benckmark ./icdar2015
cd ../
fi
+ if [ ${model_name} == "layoutxlm_ser" ]; then
+ pip install -r ppstructure/vqa/requirements.txt
+ pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
+ cd ./train_data/ && tar xf XFUND.tar
+ # expand gt.txt 10 times
+ cd XFUND/zh_train
+ for i in `seq 10`;do cp train.json dup$i.txt;done
+ cat dup* > train.json && rm -rf dup*
+ cd ../../
+
+ cd ../
+ fi
fi
if [ ${MODE} = "lite_train_lite_infer" ];then
@@ -161,7 +184,7 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
cd ./pretrain_models/ && tar xf en_server_pgnetA.tar && cd ../
cd ./train_data && tar xf total_text_lite.tar && ln -s total_text_lite total_text && cd ../
fi
- if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ] || [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then
+ if [ ${model_name} == "det_r50_vd_sast_icdar15_v2_0" ] || [ ${model_name} == "det_r50_vd_sast_totaltext_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
@@ -172,16 +195,16 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "det_r50_db_v2.0" ]; then
+ if [ ${model_name} == "det_r50_db_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_FPGM" ]; then
+ if [ ${model_name} == "ch_ppocr_mobile_v2_0_rec_FPGM" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf ch_ppocr_mobile_v2.0_rec_train.tar && cd ../
fi
- if [ ${model_name} == "det_mv3_east_v2.0" ]; then
+ if [ ${model_name} == "det_mv3_east_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_mv3_east_v2.0_train.tar && cd ../
fi
@@ -189,10 +212,21 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "det_r50_dcn_fce_ctw_v2.0" ]; then
+ if [ ${model_name} == "det_r50_dcn_fce_ctw_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar & cd ../
fi
+ if [ ${model_name} == "rec_r32_gaspin_bilstm_att" ]; then
+ wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/rec_r32_gaspin_bilstm_att_train.tar --no-check-certificate
+ cd ./pretrain_models/ && tar xf rec_r32_gaspin_bilstm_att_train.tar && cd ../
+ fi
+ if [ ${model_name} == "layoutxlm_ser" ]; then
+ pip install -r ppstructure/vqa/requirements.txt
+ pip install paddlenlp\>=2.3.5 --force-reinstall -i https://mirrors.aliyun.com/pypi/simple/
+ wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/ppstructure/dataset/XFUND.tar --no-check-certificate
+ cd ./train_data/ && tar xf XFUND.tar
+ cd ../
+ fi
elif [ ${MODE} = "whole_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
@@ -220,7 +254,7 @@ elif [ ${MODE} = "whole_train_whole_infer" ];then
cd ./pretrain_models/ && tar xf en_server_pgnetA.tar && cd ../
cd ./train_data && tar xf total_text.tar && ln -s total_text_lite total_text && cd ../
fi
- if [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then
+ if [ ${model_name} == "det_r50_vd_sast_totaltext_v2_0" ]; then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/total_text_lite.tar --no-check-certificate
cd ./train_data && tar xf total_text.tar && ln -s total_text_lite total_text && cd ../
@@ -264,32 +298,32 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
cd ./inference && tar xf rec_inference.tar && tar xf ch_det_data_50.tar && cd ../
- if [ ${model_name} = "ch_ppocr_mobile_v2.0_det" ]; then
+ if [ ${model_name} = "ch_ppocr_mobile_v2_0_det" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_det_train"
rm -rf ./train_data/icdar2015
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_det_PACT" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_det_PACT" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_det_prune_infer"
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_server_v2.0_det" ]; then
+ elif [ ${model_name} = "ch_ppocr_server_v2_0_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_train.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ elif [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_PACT" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_PACT" ]; then
eval_model_name="ch_ppocr_mobile_v2.0_rec_slim_infer"
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_FPGM" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_FPGM" ]; then
eval_model_name="ch_PP-OCRv2_rec_infer"
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ${eval_model_name}.tar && cd ../
@@ -334,39 +368,39 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar --no-check-certificate
cd ./inference && tar xf en_server_pgnetA.tar && tar xf ch_det_data_50.tar && cd ../
fi
- if [ ${model_name} == "det_r50_vd_sast_icdar15_v2.0" ]; then
+ if [ ${model_name} == "det_r50_vd_sast_icdar15_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_sast_icdar15_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi
- if [ ${model_name} == "rec_mv3_none_none_ctc_v2.0" ]; then
+ if [ ${model_name} == "rec_mv3_none_none_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_none_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mv3_none_none_ctc_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "rec_r34_vd_none_none_ctc_v2.0" ]; then
+ if [ ${model_name} == "rec_r34_vd_none_none_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_none_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_r34_vd_none_none_ctc_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "rec_mv3_none_bilstm_ctc_v2.0" ]; then
+ if [ ${model_name} == "rec_mv3_none_bilstm_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mv3_none_bilstm_ctc_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "rec_r34_vd_none_bilstm_ctc_v2.0" ]; then
+ if [ ${model_name} == "rec_r34_vd_none_bilstm_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_none_bilstm_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_r34_vd_none_bilstm_ctc_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "rec_mv3_tps_bilstm_ctc_v2.0" ]; then
+ if [ ${model_name} == "rec_mv3_tps_bilstm_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mv3_tps_bilstm_ctc_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "rec_r34_vd_tps_bilstm_ctc_v2.0" ]; then
+ if [ ${model_name} == "rec_r34_vd_tps_bilstm_ctc_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "ch_ppocr_server_v2.0_rec" ]; then
+ if [ ${model_name} == "ch_ppocr_server_v2_0_rec" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_train.tar --no-check-certificate
cd ./inference/ && tar xf ch_ppocr_server_v2.0_rec_train.tar && cd ../
fi
- if [ ${model_name} == "ch_ppocr_mobile_v2.0_rec" ]; then
+ if [ ${model_name} == "ch_ppocr_mobile_v2_0_rec" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar --no-check-certificate
cd ./inference/ && tar xf ch_ppocr_mobile_v2.0_rec_train.tar && cd ../
fi
@@ -374,11 +408,11 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mtb_nrtr_train.tar && cd ../
fi
- if [ ${model_name} == "rec_mv3_tps_bilstm_att_v2.0" ]; then
+ if [ ${model_name} == "rec_mv3_tps_bilstm_att_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_mv3_tps_bilstm_att_v2.0_train.tar && cd ../
fi
- if [ ${model_name} == "rec_r34_vd_tps_bilstm_att_v2.0" ]; then
+ if [ ${model_name} == "rec_r34_vd_tps_bilstm_att_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf rec_r34_vd_tps_bilstm_att_v2.0_train.tar && cd ../
fi
@@ -391,7 +425,7 @@ elif [ ${MODE} = "whole_infer" ];then
cd ./inference/ && tar xf rec_r50_vd_srn_train.tar && cd ../
fi
- if [ ${model_name} == "det_r50_vd_sast_totaltext_v2.0" ]; then
+ if [ ${model_name} == "det_r50_vd_sast_totaltext_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_sast_totaltext_v2.0_train.tar && cd ../
fi
@@ -399,11 +433,11 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_mv3_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi
- if [ ${model_name} == "det_r50_db_v2.0" ]; then
+ if [ ${model_name} == "det_r50_db_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_db_v2.0_train.tar && tar xf ch_det_data_50.tar && cd ../
fi
- if [ ${model_name} == "det_mv3_pse_v2.0" ]; then
+ if [ ${model_name} == "det_mv3_pse_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_mv3_pse_v2.0_train.tar & cd ../
fi
@@ -411,7 +445,7 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_pse_v2.0_train.tar & cd ../
fi
- if [ ${model_name} == "det_mv3_east_v2.0" ]; then
+ if [ ${model_name} == "det_mv3_east_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_mv3_east_v2.0_train.tar & cd ../
fi
@@ -419,7 +453,7 @@ elif [ ${MODE} = "whole_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_vd_east_v2.0_train.tar & cd ../
fi
- if [ ${model_name} == "det_r50_dcn_fce_ctw_v2.0" ]; then
+ if [ ${model_name} == "det_r50_dcn_fce_ctw_v2_0" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/contribution/det_r50_dcn_fce_ctw_v2.0_train.tar --no-check-certificate
cd ./inference/ && tar xf det_r50_dcn_fce_ctw_v2.0_train.tar & cd ../
fi
@@ -434,7 +468,7 @@ fi
if [[ ${model_name} =~ "KL" ]]; then
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015_lite.tar && rm -rf ./icdar2015 && ln -s ./icdar2015_lite ./icdar2015 && cd ../
- if [ ${model_name} = "ch_ppocr_mobile_v2.0_det_KL" ]; then
+ if [ ${model_name} = "ch_ppocr_mobile_v2_0_det_KL" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
@@ -466,7 +500,7 @@ if [[ ${model_name} =~ "KL" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
fi
- if [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_KL" ]; then
+ if [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_KL" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar --no-check-certificate
@@ -484,35 +518,35 @@ if [[ ${model_name} =~ "KL" ]]; then
fi
if [ ${MODE} = "cpp_infer" ];then
- if [ ${model_name} = "ch_ppocr_mobile_v2.0_det" ]; then
+ if [ ${model_name} = "ch_ppocr_mobile_v2_0_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_det_KL" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_det_KL" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_klquant_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_klquant_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_det_PACT" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_det_PACT" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_pact_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf rec_inference.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_KL" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_KL" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_klquant_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_rec_klquant_infer.tar && tar xf rec_inference.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0_rec_PACT" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0_rec_PACT" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_rec_pact_infer.tar && tar xf rec_inference.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_server_v2.0_det" ]; then
+ elif [ ${model_name} = "ch_ppocr_server_v2_0_det" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_server_v2.0_rec" ]; then
+ elif [ ${model_name} = "ch_ppocr_server_v2_0_rec" ]; then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf rec_inference.tar && cd ../
@@ -564,12 +598,12 @@ if [ ${MODE} = "cpp_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/rec_inference.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_rec_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_rec_pact_infer.tar && tar xf rec_inference.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_mobile_v2.0" ]; then
+ elif [ ${model_name} = "ch_ppocr_mobile_v2_0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../
- elif [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ elif [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
@@ -597,7 +631,7 @@ if [ ${MODE} = "serving_infer" ];then
${python_name} -m pip install paddle_serving_client
${python_name} -m pip install paddle-serving-app
# wget model
- if [ ${model_name} == "ch_ppocr_mobile_v2.0_det_KL" ] || [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_KL" ] ; then
+ if [ ${model_name} == "ch_ppocr_mobile_v2_0_det_KL" ] || [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_KL" ] ; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_klquant_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_klquant_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_klquant_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_klquant_infer.tar && cd ../
@@ -609,7 +643,7 @@ if [ ${MODE} = "serving_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_det_klquant_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_rec_klquant_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_det_klquant_infer.tar && tar xf ch_PP-OCRv3_rec_klquant_infer.tar && cd ../
- elif [ ${model_name} == "ch_ppocr_mobile_v2.0_det_PACT" ] || [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_PACT" ] ; then
+ elif [ ${model_name} == "ch_ppocr_mobile_v2_0_det_PACT" ] || [ ${model_name} == "ch_ppocr_mobile_v2.0_rec_PACT" ] ; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_det_pact_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_ppocr_mobile_v2.0_rec_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_pact_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_pact_infer.tar && cd ../
@@ -621,11 +655,11 @@ if [ ${MODE} = "serving_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_det_pact_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/tipc_fake_model/ch_PP-OCRv3_rec_pact_infer.tar --no-check-certificate
cd ./inference && tar xf ch_PP-OCRv3_det_pact_infer.tar && tar xf ch_PP-OCRv3_rec_pact_infer.tar && cd ../
- elif [[ ${model_name} =~ "ch_ppocr_mobile_v2.0" ]]; then
+ elif [[ ${model_name} =~ "ch_ppocr_mobile_v2_0" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && cd ../
- elif [[ ${model_name} =~ "ch_ppocr_server_v2.0" ]]; then
+ elif [[ ${model_name} =~ "ch_ppocr_server_v2_0" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && cd ../
@@ -650,11 +684,11 @@ if [ ${MODE} = "paddle2onnx_infer" ];then
${python_name} -m pip install paddle2onnx
${python_name} -m pip install onnxruntime
# wget model
- if [[ ${model_name} =~ "ch_ppocr_mobile_v2.0" ]]; then
+ if [[ ${model_name} =~ "ch_ppocr_mobile_v2_0" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && cd ../
- elif [[ ${model_name} =~ "ch_ppocr_server_v2.0" ]]; then
+ elif [[ ${model_name} =~ "ch_ppocr_server_v2_0" ]]; then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar --no-check-certificate
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar --no-check-certificate
cd ./inference && tar xf ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && cd ../
diff --git a/test_tipc/prepare_lite_cpp.sh b/test_tipc/prepare_lite_cpp.sh
index 9148cb5dd72e16790e10db1cb266e4169cd4fdab..0d3a5ca45a0dc37c90eac3f3d310bae225bb4cde 100644
--- a/test_tipc/prepare_lite_cpp.sh
+++ b/test_tipc/prepare_lite_cpp.sh
@@ -49,7 +49,7 @@ model_path=./inference_models
for model in ${lite_model_list[*]}; do
if [[ $model =~ "PP-OCRv2" ]]; then
inference_model_url=https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/${model}.tar
- elif [[ $model =~ "v2.0" ]]; then
+ elif [[ $model =~ "v2_0" ]]; then
inference_model_url=https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/${model}.tar
elif [[ $model =~ "PP-OCRv3" ]]; then
inference_model_url=https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/${model}.tar
diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh
index 356bc98041fffa8f0437c6419fc72c06d5e719f7..78d79d0b8eaac782f98c1e883d091a001443f41a 100644
--- a/test_tipc/test_paddle2onnx.sh
+++ b/test_tipc/test_paddle2onnx.sh
@@ -54,7 +54,7 @@ function func_paddle2onnx(){
_script=$1
# paddle2onnx
- if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2_0" ] || [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
# trans det
set_dirname=$(func_set_params "--model_dir" "${det_infer_model_dir_value}")
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
@@ -113,7 +113,7 @@ function func_paddle2onnx(){
_save_log_path="${LOG_PATH}/paddle2onnx_infer_cpu.log"
set_gpu=$(func_set_params "${use_gpu_key}" "${use_gpu}")
set_img_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
- if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2_0" ] || [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}")
infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
@@ -132,7 +132,7 @@ function func_paddle2onnx(){
_save_log_path="${LOG_PATH}/paddle2onnx_infer_gpu.log"
set_gpu=$(func_set_params "${use_gpu_key}" "${use_gpu}")
set_img_dir=$(func_set_params "${image_dir_key}" "${image_dir_value}")
- if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2_0" ] || [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}")
set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}")
infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 "
diff --git a/test_tipc/test_serving_infer_python.sh b/test_tipc/test_serving_infer_python.sh
index 4ccccc06e23ce086e7dac1f3446aae9130605444..4b7dfcf785a3c8459cce95d55744dbcd4f97027a 100644
--- a/test_tipc/test_serving_infer_python.sh
+++ b/test_tipc/test_serving_infer_python.sh
@@ -71,7 +71,7 @@ function func_serving(){
# pdserving
set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}")
set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}")
- if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2_0" ] || [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
# trans det
set_dirname=$(func_set_params "--dirname" "${det_infer_model_dir_value}")
set_serving_server=$(func_set_params "--serving_server" "${det_serving_server_value}")
@@ -120,7 +120,7 @@ function func_serving(){
for threads in ${web_cpu_threads_list[*]}; do
set_cpu_threads=$(func_set_params "${web_cpu_threads_key}" "${threads}")
server_log_path="${LOG_PATH}/python_server_cpu_usemkldnn_${use_mkldnn}_threads_${threads}.log"
- if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2_0" ] || [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${web_use_gpu_key}="" ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
@@ -171,7 +171,7 @@ function func_serving(){
device_type=2
fi
set_precision=$(func_set_params "${web_precision_key}" "${precision}")
- if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2.0" ] || [ ${model_name} = "ch_ppocr_server_v2.0" ]; then
+ if [ ${model_name} = "ch_PP-OCRv2" ] || [ ${model_name} = "ch_PP-OCRv3" ] || [ ${model_name} = "ch_ppocr_mobile_v2_0" ] || [ ${model_name} = "ch_ppocr_server_v2_0" ]; then
set_det_model_config=$(func_set_params "${det_server_key}" "${det_server_value}")
set_rec_model_config=$(func_set_params "${rec_server_key}" "${rec_server_value}")
web_service_cmd="nohup ${python} ${web_service_py} ${set_tensorrt} ${set_precision} ${set_det_model_config} ${set_rec_model_config} > ${server_log_path} 2>&1 &"
diff --git a/test_tipc/test_train_inference_python.sh b/test_tipc/test_train_inference_python.sh
index 402f636b1b92fa75380142803c6b513a897a89e4..545cdbba2051c8123ef7f70f2aeb4b4b5a57b7c5 100644
--- a/test_tipc/test_train_inference_python.sh
+++ b/test_tipc/test_train_inference_python.sh
@@ -101,6 +101,7 @@ function func_inference(){
_log_path=$4
_img_dir=$5
_flag_quant=$6
+ _gpu=$7
# inference
for use_gpu in ${use_gpu_list[*]}; do
if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then
@@ -119,7 +120,7 @@ function func_inference(){
fi # skip when quant model inference but precision is not int8
set_precision=$(func_set_params "${precision_key}" "${precision}")
- _save_log_path="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
+ _save_log_path="${_log_path}/python_infer_cpu_gpus_${_gpu}_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
@@ -150,7 +151,7 @@ function func_inference(){
continue
fi
for batch_size in ${batch_size_list[*]}; do
- _save_log_path="${_log_path}/python_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
+ _save_log_path="${_log_path}/python_infer_gpu_gpus_${_gpu}_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
@@ -184,6 +185,7 @@ if [ ${MODE} = "whole_infer" ]; then
# set CUDA_VISIBLE_DEVICES
eval $env
export Count=0
+ gpu=0
IFS="|"
infer_run_exports=(${infer_export_list})
infer_quant_flag=(${infer_is_quant})
@@ -205,7 +207,7 @@ if [ ${MODE} = "whole_infer" ]; then
fi
#run inference
is_quant=${infer_quant_flag[Count]}
- func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant}
+ func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant} "${gpu}"
Count=$(($Count + 1))
done
else
@@ -328,7 +330,7 @@ else
else
infer_model_dir=${save_infer_path}
fi
- func_inference "${python}" "${inference_py}" "${infer_model_dir}" "${LOG_PATH}" "${train_infer_img_dir}" "${flag_quant}"
+ func_inference "${python}" "${inference_py}" "${infer_model_dir}" "${LOG_PATH}" "${train_infer_img_dir}" "${flag_quant}" "${gpu}"
eval "unset CUDA_VISIBLE_DEVICES"
fi
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 7eb77dec74bf283936e1143edcb5b5dfc28365bd..9345106e774cfbcf0e87a7cf5d8b6cdabb4cf490 100644
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -38,6 +38,7 @@ def init_args():
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--min_subgraph_size", type=int, default=15)
+ parser.add_argument("--shape_info_filename", type=str, default=None)
parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500)
@@ -204,9 +205,18 @@ def create_predictor(args, mode, logger):
workspace_size=1 << 30,
precision_mode=precision,
max_batch_size=args.max_batch_size,
- min_subgraph_size=args.min_subgraph_size,
+ min_subgraph_size=args.min_subgraph_size, # skip the minmum trt subgraph
use_calib_mode=False)
- # skip the minmum trt subgraph
+
+ # collect shape
+ if args.shape_info_filename is not None:
+ if not os.path.exists(args.shape_info_filename):
+ config.collect_shape_range_info(args.shape_info_filename)
+ logger.info(f"collect dynamic shape info into : {args.shape_info_filename}")
+ else:
+ logger.info(f"dynamic shape info file( {args.shape_info_filename} ) already exists, not need to generate again.")
+ config.enable_tuned_tensorrt_dynamic_shape(args.shape_info_filename, True)
+
use_dynamic_shape = True
if mode == "det":
min_input_shape = {
diff --git a/tools/infer_vqa_token_ser_re.py b/tools/infer_vqa_token_ser_re.py
index 20ab1fe176c3be75f7a7b01a8d77df6419c58c75..51378bdaeb03d4ec6d7684de80625c5029963745 100755
--- a/tools/infer_vqa_token_ser_re.py
+++ b/tools/infer_vqa_token_ser_re.py
@@ -113,10 +113,13 @@ def make_input(ser_inputs, ser_results):
class SerRePredictor(object):
def __init__(self, config, ser_config):
+ global_config = config['Global']
+ if "infer_mode" in global_config:
+ ser_config["Global"]["infer_mode"] = global_config["infer_mode"]
+
self.ser_engine = SerPredictor(ser_config)
# init re model
- global_config = config['Global']
# build post process
self.post_process_class = build_post_process(config['PostProcess'],
@@ -130,8 +133,8 @@ class SerRePredictor(object):
self.model.eval()
- def __call__(self, img_path):
- ser_results, ser_inputs = self.ser_engine({'img_path': img_path})
+ def __call__(self, data):
+ ser_results, ser_inputs = self.ser_engine(data)
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
preds = self.model(re_input)
post_result = self.post_process_class(
@@ -173,18 +176,33 @@ if __name__ == '__main__':
ser_re_engine = SerRePredictor(config, ser_config)
- infer_imgs = get_image_file_list(config['Global']['infer_img'])
+ if config["Global"].get("infer_mode", None) is False:
+ data_dir = config['Eval']['dataset']['data_dir']
+ with open(config['Global']['infer_img'], "rb") as f:
+ infer_imgs = f.readlines()
+ else:
+ infer_imgs = get_image_file_list(config['Global']['infer_img'])
+
with open(
os.path.join(config['Global']['save_res_path'],
"infer_results.txt"),
"w",
encoding='utf-8') as fout:
- for idx, img_path in enumerate(infer_imgs):
+ for idx, info in enumerate(infer_imgs):
+ if config["Global"].get("infer_mode", None) is False:
+ data_line = info.decode('utf-8')
+ substr = data_line.strip("\n").split("\t")
+ img_path = os.path.join(data_dir, substr[0])
+ data = {'img_path': img_path, 'label': substr[1]}
+ else:
+ img_path = info
+ data = {'img_path': img_path}
+
save_img_path = os.path.join(
config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
- result = ser_re_engine(img_path)
+ result = ser_re_engine(data)
result = result[0]
fout.write(img_path + "\t" + json.dumps(
{
diff --git a/tools/program.py b/tools/program.py
index b2052b116d2e5de79b59f5a30b6eed7ec859ccee..d799a7e656ccea1d9b7476d56edb9fe7dcf7efe4 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -161,7 +161,7 @@ def to_float32(preds):
if isinstance(preds[k], dict) or isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
- preds[k] = preds[k].astype(paddle.float32)
+ preds[k] = paddle.to_tensor(preds[k], dtype='float32')
elif isinstance(preds, list):
for k in range(len(preds)):
if isinstance(preds[k], dict):
@@ -169,9 +169,9 @@ def to_float32(preds):
elif isinstance(preds[k], list):
preds[k] = to_float32(preds[k])
else:
- preds[k] = preds[k].astype(paddle.float32)
+ preds[k] = paddle.to_tensor(preds[k], dtype='float32')
else:
- preds = preds.astype(paddle.float32)
+ preds = paddle.to_tensor(preds, dtype='float32')
return preds
diff --git a/tools/train.py b/tools/train.py
index 309d4bb9e6b0fbcc9dd93545877662d746ada086..dc8cae8a63744bb9bd486d9899680dbde9da1697 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -119,9 +119,6 @@ def main(config, device, logger, vdl_writer):
config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture'])
- if config['Global']['distributed']:
- model = paddle.DataParallel(model)
-
model = apply_to_static(model, config, logger)
# build loss
@@ -157,10 +154,13 @@ def main(config, device, logger, vdl_writer):
scaler = paddle.amp.GradScaler(
init_loss_scaling=scale_loss,
use_dynamic_loss_scaling=use_dynamic_loss_scaling)
- model, optimizer = paddle.amp.decorate(models=model, optimizers=optimizer, level='O2', master_weight=True)
+ model, optimizer = paddle.amp.decorate(
+ models=model, optimizers=optimizer, level='O2', master_weight=True)
else:
scaler = None
+ if config['Global']['distributed']:
+ model = paddle.DataParallel(model)
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,