未验证 提交 caf77ddb 编写于 作者: A andyj 提交者: GitHub

Merge pull request #7957 from andyjpaddle/fix_vl_dict

fix visionlan default dict
...@@ -139,7 +139,7 @@ Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.9999493) ...@@ -139,7 +139,7 @@ Predicts of ./doc/imgs_words/en/word_2.png:('yourself', 0.9999493)
## 5. FAQ ## 5. FAQ
1. MJSynth和SynthText两种数据集来自于[VisionLAN源repo](https://github.com/wangyuxin87/VisionLAN) 1. MJSynth和SynthText两种数据集来自于[VisionLAN源repo](https://github.com/wangyuxin87/VisionLAN)
2. 我们使用VisionLAN作者提供的预训练模型进行finetune训练。 2. 我们使用VisionLAN作者提供的预训练模型进行finetune训练,预训练模型配套字典为'ppocr/utils/ic15_dict.txt'
## 引用 ## 引用
......
...@@ -120,7 +120,7 @@ Not supported ...@@ -120,7 +120,7 @@ Not supported
## 5. FAQ ## 5. FAQ
1. Note that the MJSynth and SynthText datasets come from [VisionLAN repo](https://github.com/wangyuxin87/VisionLAN). 1. Note that the MJSynth and SynthText datasets come from [VisionLAN repo](https://github.com/wangyuxin87/VisionLAN).
2. We use the pre-trained model provided by the VisionLAN authors for finetune training. 2. We use the pre-trained model provided by the VisionLAN authors for finetune training. The dictionary for the pre-trained model is 'ppocr/utils/ic15_dict.txt'.
## Citation ## Citation
......
...@@ -107,6 +107,7 @@ class BaseRecLabelEncode(object): ...@@ -107,6 +107,7 @@ class BaseRecLabelEncode(object):
self.beg_str = "sos" self.beg_str = "sos"
self.end_str = "eos" self.end_str = "eos"
self.lower = lower self.lower = lower
self.use_default_dict = False
if character_dict_path is None: if character_dict_path is None:
logger = get_logger() logger = get_logger()
...@@ -116,8 +117,11 @@ class BaseRecLabelEncode(object): ...@@ -116,8 +117,11 @@ class BaseRecLabelEncode(object):
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
self.lower = True self.lower = True
self.use_default_dict = True
else: else:
self.character_str = [] self.character_str = []
if 'ppocr/utils/ic15_dict.txt' in character_dict_path:
self.use_default_dict = True
with open(character_dict_path, "rb") as fin: with open(character_dict_path, "rb") as fin:
lines = fin.readlines() lines = fin.readlines()
for line in lines: for line in lines:
...@@ -1400,6 +1404,7 @@ class VLLabelEncode(BaseRecLabelEncode): ...@@ -1400,6 +1404,7 @@ class VLLabelEncode(BaseRecLabelEncode):
**kwargs): **kwargs):
super(VLLabelEncode, self).__init__( super(VLLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char, lower) max_text_length, character_dict_path, use_space_char, lower)
if self.use_default_dict:
self.character = self.character[10:] + self.character[ self.character = self.character[10:] + self.character[
1:10] + [self.character[0]] 1:10] + [self.character[0]]
self.dict = {} self.dict = {}
......
...@@ -26,10 +26,15 @@ class BaseRecLabelDecode(object): ...@@ -26,10 +26,15 @@ class BaseRecLabelDecode(object):
self.end_str = "eos" self.end_str = "eos"
self.reverse = False self.reverse = False
self.character_str = [] self.character_str = []
self.use_default_dict = False
if character_dict_path is None: if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
self.use_default_dict = True
else: else:
if 'ppocr/utils/ic15_dict.txt' in character_dict_path:
self.use_default_dict = True
with open(character_dict_path, "rb") as fin: with open(character_dict_path, "rb") as fin:
lines = fin.readlines() lines = fin.readlines()
for line in lines: for line in lines:
...@@ -805,6 +810,7 @@ class VLLabelDecode(BaseRecLabelDecode): ...@@ -805,6 +810,7 @@ class VLLabelDecode(BaseRecLabelDecode):
super(VLLabelDecode, self).__init__(character_dict_path, use_space_char) super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
self.max_text_length = kwargs.get('max_text_length', 25) self.max_text_length = kwargs.get('max_text_length', 25)
self.nclass = len(self.character) + 1 self.nclass = len(self.character) + 1
if self.use_default_dict:
self.character = self.character[10:] + self.character[ self.character = self.character[10:] + self.character[
1:10] + [self.character[0]] 1:10] + [self.character[0]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册