提交 484bf2f7 编写于 作者: xuyang2233's avatar xuyang2233

modified SPINLabelEncode SPINLabelDecode

上级 f6142746
...@@ -60,7 +60,7 @@ Loss: ...@@ -60,7 +60,7 @@ Loss:
ignore_index: 0 ignore_index: 0
PostProcess: PostProcess:
name: SPINAttnLabelDecode name: SPINLabelDecode
use_space_char: False use_space_char: False
...@@ -78,7 +78,7 @@ Train: ...@@ -78,7 +78,7 @@ Train:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- SPINAttnLabelEncode: # Class handling label - SPINLabelEncode: # Class handling label
- SPINRecResizeImg: - SPINRecResizeImg:
image_shape: [100, 32] image_shape: [100, 32]
interpolation : 2 interpolation : 2
...@@ -101,7 +101,7 @@ Eval: ...@@ -101,7 +101,7 @@ Eval:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- SPINAttnLabelEncode: # Class handling label - SPINLabelEncode: # Class handling label
- SPINRecResizeImg: - SPINRecResizeImg:
image_shape: [100, 32] image_shape: [100, 32]
interpolation : 2 interpolation : 2
......
...@@ -1217,7 +1217,7 @@ class ABINetLabelEncode(BaseRecLabelEncode): ...@@ -1217,7 +1217,7 @@ class ABINetLabelEncode(BaseRecLabelEncode):
dict_character = ['</s>'] + dict_character dict_character = ['</s>'] + dict_character
return dict_character return dict_character
class SPINAttnLabelEncode(AttnLabelEncode): class SPINLabelEncode(AttnLabelEncode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
def __init__(self, def __init__(self,
...@@ -1226,7 +1226,7 @@ class SPINAttnLabelEncode(AttnLabelEncode): ...@@ -1226,7 +1226,7 @@ class SPINAttnLabelEncode(AttnLabelEncode):
use_space_char=False, use_space_char=False,
lower=True, lower=True,
**kwargs): **kwargs):
super(SPINAttnLabelEncode, self).__init__( super(SPINLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char) max_text_length, character_dict_path, use_space_char)
self.lower = lower self.lower = lower
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
......
...@@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess ...@@ -28,7 +28,7 @@ from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \ DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \ SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode, \
SPINAttnLabelDecode SPINLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
...@@ -45,7 +45,7 @@ def build_post_process(config, global_config=None): ...@@ -45,7 +45,7 @@ def build_post_process(config, global_config=None):
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode', 'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
'TableMasterLabelDecode', 'SPINAttnLabelDecode' 'TableMasterLabelDecode', 'SPINLabelDecode'
] ]
if config['name'] == 'PSEPostProcess': if config['name'] == 'PSEPostProcess':
......
...@@ -668,12 +668,12 @@ class ABINetLabelDecode(NRTRLabelDecode): ...@@ -668,12 +668,12 @@ class ABINetLabelDecode(NRTRLabelDecode):
dict_character = ['</s>'] + dict_character dict_character = ['</s>'] + dict_character
return dict_character return dict_character
class SPINAttnLabelDecode(AttnLabelDecode): class SPINLabelDecode(AttnLabelDecode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
def __init__(self, character_dict_path=None, use_space_char=False, def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs): **kwargs):
super(SPINAttnLabelDecode, self).__init__(character_dict_path, super(SPINLabelDecode, self).__init__(character_dict_path,
use_space_char) use_space_char)
def add_special_char(self, dict_character): def add_special_char(self, dict_character):
......
...@@ -61,7 +61,7 @@ Loss: ...@@ -61,7 +61,7 @@ Loss:
ignore_index: 0 ignore_index: 0
PostProcess: PostProcess:
name: SPINAttnLabelDecode name: SPINLabelDecode
use_space_char: False use_space_char: False
...@@ -79,7 +79,7 @@ Train: ...@@ -79,7 +79,7 @@ Train:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- SPINAttnLabelEncode: # Class handling label - SPINLabelEncode: # Class handling label
- SPINRecResizeImg: - SPINRecResizeImg:
image_shape: [100, 32] image_shape: [100, 32]
interpolation : 2 interpolation : 2
...@@ -102,7 +102,7 @@ Eval: ...@@ -102,7 +102,7 @@ Eval:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- SPINAttnLabelEncode: # Class handling label - SPINLabelEncode: # Class handling label
- SPINRecResizeImg: - SPINRecResizeImg:
image_shape: [100, 32] image_shape: [100, 32]
interpolation : 2 interpolation : 2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册