Unverified commit e3cd3433, authored by xlg-go, committed by GitHub

rec_r45_abinet.yml add max_length and image_size (#10744)

* rec_r45_abinet.yml add max_length and image_shape

* image_shape to image_size
Parent 66461c33
@@ -16,7 +16,7 @@ Global:
   # for data or label process
   character_dict_path:
   character_type: en
-  max_text_length: 25
+  max_text_length: &max_text_length 25
   infer_mode: False
   use_space_char: False
   save_res_path: ./output/rec/predicts_abinet.txt
@@ -45,6 +45,8 @@ Architecture:
     name: ABINetHead
     use_lang: True
     iter_size: 3
+    max_length: *max_text_length
+    image_size: [ &h 32, &w 128 ] # [ h, w ]
 Loss:
@@ -70,7 +72,7 @@ Train:
       - ABINetLabelEncode: # Class handling label
           ignore_index: *ignore_index
       - ABINetRecResizeImg:
-          image_shape: [3, 32, 128]
+          image_shape: [3, *h, *w]
       - KeepKeys:
           keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
   loader:
@@ -90,7 +92,7 @@ Eval:
       - ABINetLabelEncode: # Class handling label
           ignore_index: *ignore_index
       - ABINetRecResizeImg:
-          image_shape: [3, 32, 128]
+          image_shape: [3, *h, *w]
       - KeepKeys:
           keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
   loader:
......
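
Note on the YAML changes above: the config now defines `&max_text_length`, `&h` and `&w` as anchors and reuses them through aliases, so the head and the resize transforms read the same values. The following is a minimal sketch of how those aliases resolve, assuming PyYAML is installed; the `CONFIG` string is a trimmed-down stand-in containing only the keys touched by this commit, not the full config file.

```python
import yaml  # PyYAML, assumed to be available

# Trimmed-down stand-in for the config: only the keys changed in this commit.
CONFIG = """
Global:
  max_text_length: &max_text_length 25
Architecture:
  Head:
    name: ABINetHead
    max_length: *max_text_length
    image_size: [ &h 32, &w 128 ]  # [ h, w ]
Train:
  dataset:
    transforms:
      - ABINetRecResizeImg:
          image_shape: [3, *h, *w]
"""

cfg = yaml.safe_load(CONFIG)
head = cfg["Architecture"]["Head"]
resize = cfg["Train"]["dataset"]["transforms"][0]["ABINetRecResizeImg"]

# Aliases are expanded by the parser, so consumers only see plain values.
print(head["max_length"])     # 25
print(head["image_size"])     # [32, 128]
print(resize["image_shape"])  # [3, 32, 128]
```

Because the aliases are expanded at load time, changing `32` or `128` in `image_size` is enough to update both `image_shape` entries.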
@@ -182,11 +182,13 @@ class ABINetHead(nn.Layer):
                  dropout=0.1,
                  max_length=25,
                  use_lang=False,
-                 iter_size=1):
+                 iter_size=1,
+                 image_size=(32, 128)):
         super().__init__()
         self.max_length = max_length + 1
+        h, w = image_size[0] // 4, image_size[1] // 4
         self.pos_encoder = PositionalEncoding(
-            dropout=0.1, dim=d_model, max_len=8 * 32)
+            dropout=0.1, dim=d_model, max_len=h * w)
         self.encoder = nn.LayerList([
             TransformerBlock(
                 d_model=d_model,
@@ -199,7 +201,7 @@ class ABINetHead(nn.Layer):
         ])
         self.decoder = PositionAttention(
             max_length=max_length + 1,  # additional stop token
-            mode='nearest', )
+            mode='nearest', h=h, w=w)
         self.out_channels = out_channels
         self.cls = nn.Linear(d_model, self.out_channels)
         self.use_lang = use_lang
......
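
Note on the `ABINetHead` change above: the hard-coded `max_len=8 * 32` is replaced by `h * w`, where `h` and `w` are the input height and width divided by 4. A small sketch of that arithmetic follows; `feature_map_size` is a hypothetical helper written for illustration and is not part of the commit, and the stride of 4 is taken directly from the `// 4` in the diff rather than derived from the backbone.

```python
def feature_map_size(image_size, stride=4):
    """Return (h, w) for an input of image_size = (H, W).

    Hypothetical helper: it mirrors the `image_size[i] // 4` expressions
    in the diff, with the stride exposed as a parameter for clarity.
    """
    return image_size[0] // stride, image_size[1] // stride


# Default from the new signature: image_size=(32, 128).
h, w = feature_map_size((32, 128))
print(h, w, h * w)  # 8 32 256

# A different input size (chosen arbitrarily for illustration).
h2, w2 = feature_map_size((48, 160))
print(h2, w2, h2 * w2)  # 12 40 480
```

With the default `image_size`, `h * w` equals the previous constant `8 * 32`, so existing configs should see the same positional-encoding length; other sizes now get a matching length instead of the fixed 256.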