Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
30d908b6
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
30d908b6
编写于
10月 13, 2021
作者:
X
xiaoting
提交者:
GitHub
10月 13, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4316 from Topdu/release/2.3
pick fix nrtr export inference model from drgraph to release/2.3
上级
6c6f19d8
309a4758
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
193 addition
and
166 deletion
+193
-166
configs/rec/rec_mtb_nrtr.yml
configs/rec/rec_mtb_nrtr.yml
+3
-3
ppocr/data/imaug/label_ops.py
ppocr/data/imaug/label_ops.py
+6
-1
ppocr/data/imaug/rec_img_aug.py
ppocr/data/imaug/rec_img_aug.py
+22
-1
ppocr/data/simple_dataset.py
ppocr/data/simple_dataset.py
+0
-1
ppocr/modeling/backbones/rec_nrtr_mtb.py
ppocr/modeling/backbones/rec_nrtr_mtb.py
+5
-3
ppocr/modeling/heads/multiheadAttention.py
ppocr/modeling/heads/multiheadAttention.py
+37
-52
ppocr/modeling/heads/rec_nrtr_head.py
ppocr/modeling/heads/rec_nrtr_head.py
+83
-101
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+13
-1
tools/export_model.py
tools/export_model.py
+2
-0
tools/infer/predict_rec.py
tools/infer/predict_rec.py
+22
-3
未找到文件。
configs/rec/rec_mtb_nrtr.yml
浏览文件 @
30d908b6
...
@@ -46,7 +46,7 @@ Architecture:
...
@@ -46,7 +46,7 @@ Architecture:
name
:
Transformer
name
:
Transformer
d_model
:
512
d_model
:
512
num_encoder_layers
:
6
num_encoder_layers
:
6
beam_size
:
10
# When Beam size is greater than 0, it means to use beam search when evaluation.
beam_size
:
-1
# When Beam size is greater than 0, it means to use beam search when evaluation.
Loss
:
Loss
:
...
@@ -65,7 +65,7 @@ Train:
...
@@ -65,7 +65,7 @@ Train:
name
:
LMDBDataSet
name
:
LMDBDataSet
data_dir
:
./train_data/data_lmdb_release/training/
data_dir
:
./train_data/data_lmdb_release/training/
transforms
:
transforms
:
-
NRTR
DecodeImage
:
# load image
-
DecodeImage
:
# load image
img_mode
:
BGR
img_mode
:
BGR
channel_first
:
False
channel_first
:
False
-
NRTRLabelEncode
:
# Class handling label
-
NRTRLabelEncode
:
# Class handling label
...
@@ -85,7 +85,7 @@ Eval:
...
@@ -85,7 +85,7 @@ Eval:
name
:
LMDBDataSet
name
:
LMDBDataSet
data_dir
:
./train_data/data_lmdb_release/evaluation/
data_dir
:
./train_data/data_lmdb_release/evaluation/
transforms
:
transforms
:
-
NRTR
DecodeImage
:
# load image
-
DecodeImage
:
# load image
img_mode
:
BGR
img_mode
:
BGR
channel_first
:
False
channel_first
:
False
-
NRTRLabelEncode
:
# Class handling label
-
NRTRLabelEncode
:
# Class handling label
...
...
ppocr/data/imaug/label_ops.py
浏览文件 @
30d908b6
...
@@ -174,21 +174,26 @@ class NRTRLabelEncode(BaseRecLabelEncode):
...
@@ -174,21 +174,26 @@ class NRTRLabelEncode(BaseRecLabelEncode):
super
(
NRTRLabelEncode
,
super
(
NRTRLabelEncode
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
self
).
__init__
(
max_text_length
,
character_dict_path
,
character_type
,
use_space_char
)
character_type
,
use_space_char
)
def
__call__
(
self
,
data
):
def
__call__
(
self
,
data
):
text
=
data
[
'label'
]
text
=
data
[
'label'
]
text
=
self
.
encode
(
text
)
text
=
self
.
encode
(
text
)
if
text
is
None
:
if
text
is
None
:
return
None
return
None
if
len
(
text
)
>=
self
.
max_text_len
-
1
:
return
None
data
[
'length'
]
=
np
.
array
(
len
(
text
))
data
[
'length'
]
=
np
.
array
(
len
(
text
))
text
.
insert
(
0
,
2
)
text
.
insert
(
0
,
2
)
text
.
append
(
3
)
text
.
append
(
3
)
text
=
text
+
[
0
]
*
(
self
.
max_text_len
-
len
(
text
))
text
=
text
+
[
0
]
*
(
self
.
max_text_len
-
len
(
text
))
data
[
'label'
]
=
np
.
array
(
text
)
data
[
'label'
]
=
np
.
array
(
text
)
return
data
return
data
def
add_special_char
(
self
,
dict_character
):
def
add_special_char
(
self
,
dict_character
):
dict_character
=
[
'blank'
,
'<unk>'
,
'<s>'
,
'</s>'
]
+
dict_character
dict_character
=
[
'blank'
,
'<unk>'
,
'<s>'
,
'</s>'
]
+
dict_character
return
dict_character
return
dict_character
class
CTCLabelEncode
(
BaseRecLabelEncode
):
class
CTCLabelEncode
(
BaseRecLabelEncode
):
""" Convert between text-label and text-index """
""" Convert between text-label and text-index """
...
...
ppocr/data/imaug/rec_img_aug.py
浏览文件 @
30d908b6
...
@@ -44,12 +44,33 @@ class ClsResizeImg(object):
...
@@ -44,12 +44,33 @@ class ClsResizeImg(object):
class
NRTRRecResizeImg
(
object
):
class
NRTRRecResizeImg
(
object
):
def
__init__
(
self
,
image_shape
,
resize_type
,
**
kwargs
):
def
__init__
(
self
,
image_shape
,
resize_type
,
padding
=
False
,
**
kwargs
):
self
.
image_shape
=
image_shape
self
.
image_shape
=
image_shape
self
.
resize_type
=
resize_type
self
.
resize_type
=
resize_type
self
.
padding
=
padding
def
__call__
(
self
,
data
):
def
__call__
(
self
,
data
):
img
=
data
[
'image'
]
img
=
data
[
'image'
]
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2GRAY
)
image_shape
=
self
.
image_shape
if
self
.
padding
:
imgC
,
imgH
,
imgW
=
image_shape
# todo: change to 0 and modified image shape
h
=
img
.
shape
[
0
]
w
=
img
.
shape
[
1
]
ratio
=
w
/
float
(
h
)
if
math
.
ceil
(
imgH
*
ratio
)
>
imgW
:
resized_w
=
imgW
else
:
resized_w
=
int
(
math
.
ceil
(
imgH
*
ratio
))
resized_image
=
cv2
.
resize
(
img
,
(
resized_w
,
imgH
))
norm_img
=
np
.
expand_dims
(
resized_image
,
-
1
)
norm_img
=
norm_img
.
transpose
((
2
,
0
,
1
))
resized_image
=
norm_img
.
astype
(
np
.
float32
)
/
128.
-
1.
padding_im
=
np
.
zeros
((
imgC
,
imgH
,
imgW
),
dtype
=
np
.
float32
)
padding_im
[:,
:,
0
:
resized_w
]
=
resized_image
data
[
'image'
]
=
padding_im
return
data
if
self
.
resize_type
==
'PIL'
:
if
self
.
resize_type
==
'PIL'
:
image_pil
=
Image
.
fromarray
(
np
.
uint8
(
img
))
image_pil
=
Image
.
fromarray
(
np
.
uint8
(
img
))
img
=
image_pil
.
resize
(
self
.
image_shape
,
Image
.
ANTIALIAS
)
img
=
image_pil
.
resize
(
self
.
image_shape
,
Image
.
ANTIALIAS
)
...
...
ppocr/data/simple_dataset.py
浏览文件 @
30d908b6
...
@@ -15,7 +15,6 @@ import numpy as np
...
@@ -15,7 +15,6 @@ import numpy as np
import
os
import
os
import
random
import
random
from
paddle.io
import
Dataset
from
paddle.io
import
Dataset
from
.imaug
import
transform
,
create_operators
from
.imaug
import
transform
,
create_operators
...
...
ppocr/modeling/backbones/rec_nrtr_mtb.py
浏览文件 @
30d908b6
...
@@ -13,6 +13,7 @@
...
@@ -13,6 +13,7 @@
# limitations under the License.
# limitations under the License.
from
paddle
import
nn
from
paddle
import
nn
import
paddle
class
MTB
(
nn
.
Layer
):
class
MTB
(
nn
.
Layer
):
...
@@ -40,7 +41,8 @@ class MTB(nn.Layer):
...
@@ -40,7 +41,8 @@ class MTB(nn.Layer):
x
=
self
.
block
(
images
)
x
=
self
.
block
(
images
)
if
self
.
cnn_num
==
2
:
if
self
.
cnn_num
==
2
:
# (b, w, h, c)
# (b, w, h, c)
x
=
x
.
transpose
([
0
,
3
,
2
,
1
])
x
=
paddle
.
transpose
(
x
,
[
0
,
3
,
2
,
1
])
x_shape
=
x
.
shape
x_shape
=
paddle
.
shape
(
x
)
x
=
x
.
reshape
([
x_shape
[
0
],
x_shape
[
1
],
x_shape
[
2
]
*
x_shape
[
3
]])
x
=
paddle
.
reshape
(
x
,
[
x_shape
[
0
],
x_shape
[
1
],
x_shape
[
2
]
*
x_shape
[
3
]])
return
x
return
x
ppocr/modeling/heads/multiheadAttention.py
浏览文件 @
30d908b6
...
@@ -71,8 +71,6 @@ class MultiheadAttention(nn.Layer):
...
@@ -71,8 +71,6 @@ class MultiheadAttention(nn.Layer):
value
,
value
,
key_padding_mask
=
None
,
key_padding_mask
=
None
,
incremental_state
=
None
,
incremental_state
=
None
,
need_weights
=
True
,
static_kv
=
False
,
attn_mask
=
None
):
attn_mask
=
None
):
"""
"""
Inputs of forward function
Inputs of forward function
...
@@ -88,46 +86,42 @@ class MultiheadAttention(nn.Layer):
...
@@ -88,46 +86,42 @@ class MultiheadAttention(nn.Layer):
attn_output: [target length, batch size, embed dim]
attn_output: [target length, batch size, embed dim]
attn_output_weights: [batch size, target length, sequence length]
attn_output_weights: [batch size, target length, sequence length]
"""
"""
tgt_len
,
bsz
,
embed_dim
=
query
.
shape
q_shape
=
paddle
.
shape
(
query
)
assert
embed_dim
==
self
.
embed_dim
src_shape
=
paddle
.
shape
(
key
)
assert
list
(
query
.
shape
)
==
[
tgt_len
,
bsz
,
embed_dim
]
assert
key
.
shape
==
value
.
shape
q
=
self
.
_in_proj_q
(
query
)
q
=
self
.
_in_proj_q
(
query
)
k
=
self
.
_in_proj_k
(
key
)
k
=
self
.
_in_proj_k
(
key
)
v
=
self
.
_in_proj_v
(
value
)
v
=
self
.
_in_proj_v
(
value
)
q
*=
self
.
scaling
q
*=
self
.
scaling
q
=
paddle
.
transpose
(
q
=
q
.
reshape
([
tgt_len
,
bsz
*
self
.
num_heads
,
self
.
head_dim
]).
transpose
(
paddle
.
reshape
(
[
1
,
0
,
2
])
q
,
[
q_shape
[
0
],
q_shape
[
1
],
self
.
num_heads
,
self
.
head_dim
]),
k
=
k
.
reshape
([
-
1
,
bsz
*
self
.
num_heads
,
self
.
head_dim
]).
transpose
(
[
1
,
2
,
0
,
3
])
[
1
,
0
,
2
])
k
=
paddle
.
transpose
(
v
=
v
.
reshape
([
-
1
,
bsz
*
self
.
num_heads
,
self
.
head_dim
]).
transpose
(
paddle
.
reshape
(
[
1
,
0
,
2
])
k
,
[
src_shape
[
0
],
q_shape
[
1
],
self
.
num_heads
,
self
.
head_dim
]),
[
1
,
2
,
0
,
3
])
src_len
=
k
.
shape
[
1
]
v
=
paddle
.
transpose
(
paddle
.
reshape
(
v
,
[
src_shape
[
0
],
q_shape
[
1
],
self
.
num_heads
,
self
.
head_dim
]),
[
1
,
2
,
0
,
3
])
if
key_padding_mask
is
not
None
:
if
key_padding_mask
is
not
None
:
assert
key_padding_mask
.
shape
[
0
]
==
bsz
assert
key_padding_mask
.
shape
[
0
]
==
q_shape
[
1
]
assert
key_padding_mask
.
shape
[
1
]
==
src_len
assert
key_padding_mask
.
shape
[
1
]
==
src_shape
[
0
]
attn_output_weights
=
paddle
.
matmul
(
q
,
attn_output_weights
=
paddle
.
bmm
(
q
,
k
.
transpose
([
0
,
2
,
1
]))
paddle
.
transpose
(
k
,
[
0
,
1
,
3
,
2
]))
assert
list
(
attn_output_weights
.
shape
)
==
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
]
if
attn_mask
is
not
None
:
if
attn_mask
is
not
None
:
attn_mask
=
attn_mask
.
unsqueeze
(
0
)
attn_mask
=
paddle
.
unsqueeze
(
paddle
.
unsqueeze
(
attn_mask
,
0
),
0
)
attn_output_weights
+=
attn_mask
attn_output_weights
+=
attn_mask
if
key_padding_mask
is
not
None
:
if
key_padding_mask
is
not
None
:
attn_output_weights
=
attn_output_weights
.
reshape
(
attn_output_weights
=
paddle
.
reshape
(
[
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
])
attn_output_weights
,
key
=
key_padding_mask
.
unsqueeze
(
1
).
unsqueeze
(
2
).
astype
(
'float32'
)
[
q_shape
[
1
],
self
.
num_heads
,
q_shape
[
0
],
src_shape
[
0
]])
y
=
paddle
.
full
(
shape
=
key
.
shape
,
dtype
=
'float32'
,
fill_value
=
'-inf'
)
key
=
paddle
.
unsqueeze
(
paddle
.
unsqueeze
(
key_padding_mask
,
1
),
2
)
key
=
paddle
.
cast
(
key
,
'float32'
)
y
=
paddle
.
full
(
shape
=
paddle
.
shape
(
key
),
dtype
=
'float32'
,
fill_value
=
'-inf'
)
y
=
paddle
.
where
(
key
==
0.
,
key
,
y
)
y
=
paddle
.
where
(
key
==
0.
,
key
,
y
)
attn_output_weights
+=
y
attn_output_weights
+=
y
attn_output_weights
=
attn_output_weights
.
reshape
(
[
bsz
*
self
.
num_heads
,
tgt_len
,
src_len
])
attn_output_weights
=
F
.
softmax
(
attn_output_weights
=
F
.
softmax
(
attn_output_weights
.
astype
(
'float32'
),
attn_output_weights
.
astype
(
'float32'
),
axis
=-
1
,
axis
=-
1
,
...
@@ -136,43 +130,34 @@ class MultiheadAttention(nn.Layer):
...
@@ -136,43 +130,34 @@ class MultiheadAttention(nn.Layer):
attn_output_weights
=
F
.
dropout
(
attn_output_weights
=
F
.
dropout
(
attn_output_weights
,
p
=
self
.
dropout
,
training
=
self
.
training
)
attn_output_weights
,
p
=
self
.
dropout
,
training
=
self
.
training
)
attn_output
=
paddle
.
bmm
(
attn_output_weights
,
v
)
attn_output
=
paddle
.
matmul
(
attn_output_weights
,
v
)
assert
list
(
attn_output
.
attn_output
=
paddle
.
reshape
(
shape
)
==
[
bsz
*
self
.
num_heads
,
tgt_len
,
self
.
head_dim
]
paddle
.
transpose
(
attn_output
,
[
2
,
0
,
1
,
3
]),
attn_output
=
attn_output
.
transpose
([
1
,
0
,
2
]).
reshape
(
[
q_shape
[
0
],
q_shape
[
1
],
self
.
embed_dim
])
[
tgt_len
,
bsz
,
embed_dim
])
attn_output
=
self
.
out_proj
(
attn_output
)
attn_output
=
self
.
out_proj
(
attn_output
)
if
need_weights
:
return
attn_output
# average attention weights over heads
attn_output_weights
=
attn_output_weights
.
reshape
(
[
bsz
,
self
.
num_heads
,
tgt_len
,
src_len
])
attn_output_weights
=
attn_output_weights
.
sum
(
axis
=
1
)
/
self
.
num_heads
else
:
attn_output_weights
=
None
return
attn_output
,
attn_output_weights
def
_in_proj_q
(
self
,
query
):
def
_in_proj_q
(
self
,
query
):
query
=
query
.
transpose
(
[
1
,
2
,
0
])
query
=
paddle
.
transpose
(
query
,
[
1
,
2
,
0
])
query
=
paddle
.
unsqueeze
(
query
,
axis
=
2
)
query
=
paddle
.
unsqueeze
(
query
,
axis
=
2
)
res
=
self
.
conv1
(
query
)
res
=
self
.
conv1
(
query
)
res
=
paddle
.
squeeze
(
res
,
axis
=
2
)
res
=
paddle
.
squeeze
(
res
,
axis
=
2
)
res
=
res
.
transpose
(
[
2
,
0
,
1
])
res
=
paddle
.
transpose
(
res
,
[
2
,
0
,
1
])
return
res
return
res
def
_in_proj_k
(
self
,
key
):
def
_in_proj_k
(
self
,
key
):
key
=
key
.
transpose
(
[
1
,
2
,
0
])
key
=
paddle
.
transpose
(
key
,
[
1
,
2
,
0
])
key
=
paddle
.
unsqueeze
(
key
,
axis
=
2
)
key
=
paddle
.
unsqueeze
(
key
,
axis
=
2
)
res
=
self
.
conv2
(
key
)
res
=
self
.
conv2
(
key
)
res
=
paddle
.
squeeze
(
res
,
axis
=
2
)
res
=
paddle
.
squeeze
(
res
,
axis
=
2
)
res
=
res
.
transpose
(
[
2
,
0
,
1
])
res
=
paddle
.
transpose
(
res
,
[
2
,
0
,
1
])
return
res
return
res
def
_in_proj_v
(
self
,
value
):
def
_in_proj_v
(
self
,
value
):
value
=
value
.
transpose
(
[
1
,
2
,
0
])
#(1, 2, 0)
value
=
paddle
.
transpose
(
value
,
[
1
,
2
,
0
])
#(1, 2, 0)
value
=
paddle
.
unsqueeze
(
value
,
axis
=
2
)
value
=
paddle
.
unsqueeze
(
value
,
axis
=
2
)
res
=
self
.
conv3
(
value
)
res
=
self
.
conv3
(
value
)
res
=
paddle
.
squeeze
(
res
,
axis
=
2
)
res
=
paddle
.
squeeze
(
res
,
axis
=
2
)
res
=
res
.
transpose
(
[
2
,
0
,
1
])
res
=
paddle
.
transpose
(
res
,
[
2
,
0
,
1
])
return
res
return
res
ppocr/modeling/heads/rec_nrtr_head.py
浏览文件 @
30d908b6
...
@@ -61,12 +61,12 @@ class Transformer(nn.Layer):
...
@@ -61,12 +61,12 @@ class Transformer(nn.Layer):
custom_decoder
=
None
,
custom_decoder
=
None
,
in_channels
=
0
,
in_channels
=
0
,
out_channels
=
0
,
out_channels
=
0
,
dst_vocab_size
=
99
,
scale_embedding
=
True
):
scale_embedding
=
True
):
super
(
Transformer
,
self
).
__init__
()
super
(
Transformer
,
self
).
__init__
()
self
.
out_channels
=
out_channels
+
1
self
.
embedding
=
Embeddings
(
self
.
embedding
=
Embeddings
(
d_model
=
d_model
,
d_model
=
d_model
,
vocab
=
dst_vocab_size
,
vocab
=
self
.
out_channels
,
padding_idx
=
0
,
padding_idx
=
0
,
scale_embedding
=
scale_embedding
)
scale_embedding
=
scale_embedding
)
self
.
positional_encoding
=
PositionalEncoding
(
self
.
positional_encoding
=
PositionalEncoding
(
...
@@ -96,9 +96,10 @@ class Transformer(nn.Layer):
...
@@ -96,9 +96,10 @@ class Transformer(nn.Layer):
self
.
beam_size
=
beam_size
self
.
beam_size
=
beam_size
self
.
d_model
=
d_model
self
.
d_model
=
d_model
self
.
nhead
=
nhead
self
.
nhead
=
nhead
self
.
tgt_word_prj
=
nn
.
Linear
(
d_model
,
dst_vocab_size
,
bias_attr
=
False
)
self
.
tgt_word_prj
=
nn
.
Linear
(
d_model
,
self
.
out_channels
,
bias_attr
=
False
)
w0
=
np
.
random
.
normal
(
0.0
,
d_model
**-
0.5
,
w0
=
np
.
random
.
normal
(
0.0
,
d_model
**-
0.5
,
(
d_model
,
dst_vocab_size
)).
astype
(
np
.
float32
)
(
d_model
,
self
.
out_channels
)).
astype
(
np
.
float32
)
self
.
tgt_word_prj
.
weight
.
set_value
(
w0
)
self
.
tgt_word_prj
.
weight
.
set_value
(
w0
)
self
.
apply
(
self
.
_init_weights
)
self
.
apply
(
self
.
_init_weights
)
...
@@ -156,46 +157,41 @@ class Transformer(nn.Layer):
...
@@ -156,46 +157,41 @@ class Transformer(nn.Layer):
return
self
.
forward_test
(
src
)
return
self
.
forward_test
(
src
)
def
forward_test
(
self
,
src
):
def
forward_test
(
self
,
src
):
bs
=
src
.
shape
[
0
]
bs
=
paddle
.
shape
(
src
)
[
0
]
if
self
.
encoder
is
not
None
:
if
self
.
encoder
is
not
None
:
src
=
self
.
positional_encoding
(
src
.
transpose
(
[
1
,
0
,
2
]))
src
=
self
.
positional_encoding
(
paddle
.
transpose
(
src
,
[
1
,
0
,
2
]))
memory
=
self
.
encoder
(
src
)
memory
=
self
.
encoder
(
src
)
else
:
else
:
memory
=
src
.
squeeze
(
2
).
transpose
(
[
2
,
0
,
1
])
memory
=
paddle
.
transpose
(
paddle
.
squeeze
(
src
,
2
),
[
2
,
0
,
1
])
dec_seq
=
paddle
.
full
((
bs
,
1
),
2
,
dtype
=
paddle
.
int64
)
dec_seq
=
paddle
.
full
((
bs
,
1
),
2
,
dtype
=
paddle
.
int64
)
dec_prob
=
paddle
.
full
((
bs
,
1
),
1.
,
dtype
=
paddle
.
float32
)
for
len_dec_seq
in
range
(
1
,
25
):
for
len_dec_seq
in
range
(
1
,
25
):
src_enc
=
memory
.
clone
()
dec_seq_embed
=
paddle
.
transpose
(
self
.
embedding
(
dec_seq
),
[
1
,
0
,
2
])
tgt_key_padding_mask
=
self
.
generate_padding_mask
(
dec_seq
)
dec_seq_embed
=
self
.
embedding
(
dec_seq
).
transpose
([
1
,
0
,
2
])
dec_seq_embed
=
self
.
positional_encoding
(
dec_seq_embed
)
dec_seq_embed
=
self
.
positional_encoding
(
dec_seq_embed
)
tgt_mask
=
self
.
generate_square_subsequent_mask
(
dec_seq_embed
.
shape
[
tgt_mask
=
self
.
generate_square_subsequent_mask
(
0
])
paddle
.
shape
(
dec_seq_embed
)[
0
])
output
=
self
.
decoder
(
output
=
self
.
decoder
(
dec_seq_embed
,
dec_seq_embed
,
src_enc
,
memory
,
tgt_mask
=
tgt_mask
,
tgt_mask
=
tgt_mask
,
memory_mask
=
None
,
memory_mask
=
None
,
tgt_key_padding_mask
=
tgt_key_padding_mask
,
tgt_key_padding_mask
=
None
,
memory_key_padding_mask
=
None
)
memory_key_padding_mask
=
None
)
dec_output
=
output
.
transpose
([
1
,
0
,
2
])
dec_output
=
paddle
.
transpose
(
output
,
[
1
,
0
,
2
])
dec_output
=
dec_output
[:,
-
1
,
:]
dec_output
=
dec_output
[:,
word_prob
=
F
.
softmax
(
self
.
tgt_word_prj
(
dec_output
),
axis
=
1
)
-
1
,
:]
# Pick the last step: (bh * bm) * d_h
preds_idx
=
paddle
.
argmax
(
word_prob
,
axis
=
1
)
word_prob
=
F
.
log_softmax
(
self
.
tgt_word_prj
(
dec_output
),
axis
=
1
)
word_prob
=
word_prob
.
reshape
([
1
,
bs
,
-
1
])
preds_idx
=
word_prob
.
argmax
(
axis
=
2
)
if
paddle
.
equal_all
(
if
paddle
.
equal_all
(
preds_idx
[
-
1
]
,
preds_idx
,
paddle
.
full
(
paddle
.
full
(
p
reds_idx
[
-
1
].
shape
,
3
,
dtype
=
'int64'
)):
p
addle
.
shape
(
preds_idx
)
,
3
,
dtype
=
'int64'
)):
break
break
preds_prob
=
paddle
.
max
(
word_prob
,
axis
=
1
)
preds_prob
=
word_prob
.
max
(
axis
=
2
)
dec_seq
=
paddle
.
concat
(
dec_seq
=
paddle
.
concat
(
[
dec_seq
,
preds_idx
.
reshape
([
-
1
,
1
])],
axis
=
1
)
[
dec_seq
,
paddle
.
reshape
(
preds_idx
,
[
-
1
,
1
])],
axis
=
1
)
dec_prob
=
paddle
.
concat
(
return
dec_seq
[
dec_prob
,
paddle
.
reshape
(
preds_prob
,
[
-
1
,
1
])],
axis
=
1
)
return
[
dec_seq
,
dec_prob
]
def
forward_beam
(
self
,
images
):
def
forward_beam
(
self
,
images
):
''' Translation work in one batch '''
''' Translation work in one batch '''
...
@@ -211,14 +207,15 @@ class Transformer(nn.Layer):
...
@@ -211,14 +207,15 @@ class Transformer(nn.Layer):
n_prev_active_inst
,
n_bm
):
n_prev_active_inst
,
n_bm
):
''' Collect tensor parts associated to active instances. '''
''' Collect tensor parts associated to active instances. '''
_
,
*
d_hs
=
beamed_tensor
.
shape
beamed_tensor_shape
=
paddle
.
shape
(
beamed_tensor
)
n_curr_active_inst
=
len
(
curr_active_inst_idx
)
n_curr_active_inst
=
len
(
curr_active_inst_idx
)
new_shape
=
(
n_curr_active_inst
*
n_bm
,
*
d_hs
)
new_shape
=
(
n_curr_active_inst
*
n_bm
,
beamed_tensor_shape
[
1
],
beamed_tensor_shape
[
2
])
beamed_tensor
=
beamed_tensor
.
reshape
([
n_prev_active_inst
,
-
1
])
beamed_tensor
=
beamed_tensor
.
reshape
([
n_prev_active_inst
,
-
1
])
beamed_tensor
=
beamed_tensor
.
index_select
(
beamed_tensor
=
beamed_tensor
.
index_select
(
paddle
.
to_tensor
(
curr_active_inst_idx
)
,
axis
=
0
)
curr_active_inst_idx
,
axis
=
0
)
beamed_tensor
=
beamed_tensor
.
reshape
(
[
*
new_shape
]
)
beamed_tensor
=
beamed_tensor
.
reshape
(
new_shape
)
return
beamed_tensor
return
beamed_tensor
...
@@ -249,44 +246,26 @@ class Transformer(nn.Layer):
...
@@ -249,44 +246,26 @@ class Transformer(nn.Layer):
b
.
get_current_state
()
for
b
in
inst_dec_beams
if
not
b
.
done
b
.
get_current_state
()
for
b
in
inst_dec_beams
if
not
b
.
done
]
]
dec_partial_seq
=
paddle
.
stack
(
dec_partial_seq
)
dec_partial_seq
=
paddle
.
stack
(
dec_partial_seq
)
dec_partial_seq
=
dec_partial_seq
.
reshape
([
-
1
,
len_dec_seq
])
dec_partial_seq
=
dec_partial_seq
.
reshape
([
-
1
,
len_dec_seq
])
return
dec_partial_seq
return
dec_partial_seq
def
prepare_beam_memory_key_padding_mask
(
inst_dec_beams
,
memory_key_padding_mask
,
n_bm
):
keep
=
[]
for
idx
in
(
memory_key_padding_mask
):
if
not
inst_dec_beams
[
idx
].
done
:
keep
.
append
(
idx
)
memory_key_padding_mask
=
memory_key_padding_mask
[
paddle
.
to_tensor
(
keep
)]
len_s
=
memory_key_padding_mask
.
shape
[
-
1
]
n_inst
=
memory_key_padding_mask
.
shape
[
0
]
memory_key_padding_mask
=
paddle
.
concat
(
[
memory_key_padding_mask
for
i
in
range
(
n_bm
)],
axis
=
1
)
memory_key_padding_mask
=
memory_key_padding_mask
.
reshape
(
[
n_inst
*
n_bm
,
len_s
])
#repeat(1, n_bm)
return
memory_key_padding_mask
def
predict_word
(
dec_seq
,
enc_output
,
n_active_inst
,
n_bm
,
def
predict_word
(
dec_seq
,
enc_output
,
n_active_inst
,
n_bm
,
memory_key_padding_mask
):
memory_key_padding_mask
):
tgt_key_padding_mask
=
self
.
generate_padding_mask
(
dec_seq
)
dec_seq
=
paddle
.
transpose
(
self
.
embedding
(
dec_seq
),
[
1
,
0
,
2
])
dec_seq
=
self
.
embedding
(
dec_seq
).
transpose
([
1
,
0
,
2
])
dec_seq
=
self
.
positional_encoding
(
dec_seq
)
dec_seq
=
self
.
positional_encoding
(
dec_seq
)
tgt_mask
=
self
.
generate_square_subsequent_mask
(
dec_seq
.
shape
[
tgt_mask
=
self
.
generate_square_subsequent_mask
(
0
])
paddle
.
shape
(
dec_seq
)[
0
])
dec_output
=
self
.
decoder
(
dec_output
=
self
.
decoder
(
dec_seq
,
dec_seq
,
enc_output
,
enc_output
,
tgt_mask
=
tgt_mask
,
tgt_mask
=
tgt_mask
,
tgt_key_padding_mask
=
tgt_key_padding_mask
,
tgt_key_padding_mask
=
None
,
memory_key_padding_mask
=
memory_key_padding_mask
,
memory_key_padding_mask
=
memory_key_padding_mask
,
)
).
transpose
(
[
1
,
0
,
2
])
dec_output
=
paddle
.
transpose
(
dec_output
,
[
1
,
0
,
2
])
dec_output
=
dec_output
[:,
dec_output
=
dec_output
[:,
-
1
,
:]
# Pick the last step: (bh * bm) * d_h
-
1
,
:]
# Pick the last step: (bh * bm) * d_h
word_prob
=
F
.
log_
softmax
(
self
.
tgt_word_prj
(
dec_output
),
axis
=
1
)
word_prob
=
F
.
softmax
(
self
.
tgt_word_prj
(
dec_output
),
axis
=
1
)
word_prob
=
word_prob
.
reshape
(
[
n_active_inst
,
n_bm
,
-
1
])
word_prob
=
paddle
.
reshape
(
word_prob
,
[
n_active_inst
,
n_bm
,
-
1
])
return
word_prob
return
word_prob
def
collect_active_inst_idx_list
(
inst_beams
,
word_prob
,
def
collect_active_inst_idx_list
(
inst_beams
,
word_prob
,
...
@@ -302,9 +281,8 @@ class Transformer(nn.Layer):
...
@@ -302,9 +281,8 @@ class Transformer(nn.Layer):
n_active_inst
=
len
(
inst_idx_to_position_map
)
n_active_inst
=
len
(
inst_idx_to_position_map
)
dec_seq
=
prepare_beam_dec_seq
(
inst_dec_beams
,
len_dec_seq
)
dec_seq
=
prepare_beam_dec_seq
(
inst_dec_beams
,
len_dec_seq
)
memory_key_padding_mask
=
None
word_prob
=
predict_word
(
dec_seq
,
enc_output
,
n_active_inst
,
n_bm
,
word_prob
=
predict_word
(
dec_seq
,
enc_output
,
n_active_inst
,
n_bm
,
memory_key_padding_mask
)
None
)
# Update the beam with predicted word prob information and collect incomplete instances
# Update the beam with predicted word prob information and collect incomplete instances
active_inst_idx_list
=
collect_active_inst_idx_list
(
active_inst_idx_list
=
collect_active_inst_idx_list
(
inst_dec_beams
,
word_prob
,
inst_idx_to_position_map
)
inst_dec_beams
,
word_prob
,
inst_idx_to_position_map
)
...
@@ -324,27 +302,21 @@ class Transformer(nn.Layer):
...
@@ -324,27 +302,21 @@ class Transformer(nn.Layer):
with
paddle
.
no_grad
():
with
paddle
.
no_grad
():
#-- Encode
#-- Encode
if
self
.
encoder
is
not
None
:
if
self
.
encoder
is
not
None
:
src
=
self
.
positional_encoding
(
images
.
transpose
([
1
,
0
,
2
]))
src
=
self
.
positional_encoding
(
images
.
transpose
([
1
,
0
,
2
]))
src_enc
=
self
.
encoder
(
src
)
.
transpose
([
1
,
0
,
2
])
src_enc
=
self
.
encoder
(
src
)
else
:
else
:
src_enc
=
images
.
squeeze
(
2
).
transpose
([
0
,
2
,
1
])
src_enc
=
images
.
squeeze
(
2
).
transpose
([
0
,
2
,
1
])
#-- Repeat data for beam search
n_bm
=
self
.
beam_size
n_bm
=
self
.
beam_size
n_inst
,
len_s
,
d_h
=
src_enc
.
shape
src_shape
=
paddle
.
shape
(
src_enc
)
src_enc
=
paddle
.
concat
([
src_enc
for
i
in
range
(
n_bm
)],
axis
=
1
)
inst_dec_beams
=
[
Beam
(
n_bm
)
for
_
in
range
(
1
)]
src_enc
=
src_enc
.
reshape
([
n_inst
*
n_bm
,
len_s
,
d_h
]).
transpose
(
active_inst_idx_list
=
list
(
range
(
1
))
[
1
,
0
,
2
])
# Repeat data for beam search
#-- Prepare beams
src_enc
=
paddle
.
tile
(
src_enc
,
[
1
,
n_bm
,
1
])
inst_dec_beams
=
[
Beam
(
n_bm
)
for
_
in
range
(
n_inst
)]
#-- Bookkeeping for active or not
active_inst_idx_list
=
list
(
range
(
n_inst
))
inst_idx_to_position_map
=
get_inst_idx_to_tensor_position_map
(
inst_idx_to_position_map
=
get_inst_idx_to_tensor_position_map
(
active_inst_idx_list
)
active_inst_idx_list
)
#
--
Decode
# Decode
for
len_dec_seq
in
range
(
1
,
25
):
for
len_dec_seq
in
range
(
1
,
25
):
src_enc_copy
=
src_enc
.
clone
()
src_enc_copy
=
src_enc
.
clone
()
active_inst_idx_list
=
beam_decode_step
(
active_inst_idx_list
=
beam_decode_step
(
...
@@ -358,10 +330,19 @@ class Transformer(nn.Layer):
...
@@ -358,10 +330,19 @@ class Transformer(nn.Layer):
batch_hyp
,
batch_scores
=
collect_hypothesis_and_scores
(
inst_dec_beams
,
batch_hyp
,
batch_scores
=
collect_hypothesis_and_scores
(
inst_dec_beams
,
1
)
1
)
result_hyp
=
[]
result_hyp
=
[]
for
bs_hyp
in
batch_hyp
:
hyp_scores
=
[]
bs_hyp_pad
=
bs_hyp
[
0
]
+
[
3
]
*
(
25
-
len
(
bs_hyp
[
0
]))
for
bs_hyp
,
score
in
zip
(
batch_hyp
,
batch_scores
):
l
=
len
(
bs_hyp
[
0
])
bs_hyp_pad
=
bs_hyp
[
0
]
+
[
3
]
*
(
25
-
l
)
result_hyp
.
append
(
bs_hyp_pad
)
result_hyp
.
append
(
bs_hyp_pad
)
return
paddle
.
to_tensor
(
np
.
array
(
result_hyp
),
dtype
=
paddle
.
int64
)
score
=
float
(
score
)
/
l
hyp_score
=
[
score
for
_
in
range
(
25
)]
hyp_scores
.
append
(
hyp_score
)
return
[
paddle
.
to_tensor
(
np
.
array
(
result_hyp
),
dtype
=
paddle
.
int64
),
paddle
.
to_tensor
(
hyp_scores
)
]
def
generate_square_subsequent_mask
(
self
,
sz
):
def
generate_square_subsequent_mask
(
self
,
sz
):
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
...
@@ -376,7 +357,7 @@ class Transformer(nn.Layer):
...
@@ -376,7 +357,7 @@ class Transformer(nn.Layer):
return
mask
return
mask
def
generate_padding_mask
(
self
,
x
):
def
generate_padding_mask
(
self
,
x
):
padding_mask
=
x
.
equal
(
paddle
.
to_tensor
(
0
,
dtype
=
x
.
dtype
))
padding_mask
=
paddle
.
equal
(
x
,
paddle
.
to_tensor
(
0
,
dtype
=
x
.
dtype
))
return
padding_mask
return
padding_mask
def
_reset_parameters
(
self
):
def
_reset_parameters
(
self
):
...
@@ -514,17 +495,17 @@ class TransformerEncoderLayer(nn.Layer):
...
@@ -514,17 +495,17 @@ class TransformerEncoderLayer(nn.Layer):
src
,
src
,
src
,
src
,
attn_mask
=
src_mask
,
attn_mask
=
src_mask
,
key_padding_mask
=
src_key_padding_mask
)
[
0
]
key_padding_mask
=
src_key_padding_mask
)
src
=
src
+
self
.
dropout1
(
src2
)
src
=
src
+
self
.
dropout1
(
src2
)
src
=
self
.
norm1
(
src
)
src
=
self
.
norm1
(
src
)
src
=
src
.
transpose
(
[
1
,
2
,
0
])
src
=
paddle
.
transpose
(
src
,
[
1
,
2
,
0
])
src
=
paddle
.
unsqueeze
(
src
,
2
)
src
=
paddle
.
unsqueeze
(
src
,
2
)
src2
=
self
.
conv2
(
F
.
relu
(
self
.
conv1
(
src
)))
src2
=
self
.
conv2
(
F
.
relu
(
self
.
conv1
(
src
)))
src2
=
paddle
.
squeeze
(
src2
,
2
)
src2
=
paddle
.
squeeze
(
src2
,
2
)
src2
=
src2
.
transpose
(
[
2
,
0
,
1
])
src2
=
paddle
.
transpose
(
src2
,
[
2
,
0
,
1
])
src
=
paddle
.
squeeze
(
src
,
2
)
src
=
paddle
.
squeeze
(
src
,
2
)
src
=
src
.
transpose
(
[
2
,
0
,
1
])
src
=
paddle
.
transpose
(
src
,
[
2
,
0
,
1
])
src
=
src
+
self
.
dropout2
(
src2
)
src
=
src
+
self
.
dropout2
(
src2
)
src
=
self
.
norm2
(
src
)
src
=
self
.
norm2
(
src
)
...
@@ -598,7 +579,7 @@ class TransformerDecoderLayer(nn.Layer):
...
@@ -598,7 +579,7 @@ class TransformerDecoderLayer(nn.Layer):
tgt
,
tgt
,
tgt
,
tgt
,
attn_mask
=
tgt_mask
,
attn_mask
=
tgt_mask
,
key_padding_mask
=
tgt_key_padding_mask
)
[
0
]
key_padding_mask
=
tgt_key_padding_mask
)
tgt
=
tgt
+
self
.
dropout1
(
tgt2
)
tgt
=
tgt
+
self
.
dropout1
(
tgt2
)
tgt
=
self
.
norm1
(
tgt
)
tgt
=
self
.
norm1
(
tgt
)
tgt2
=
self
.
multihead_attn
(
tgt2
=
self
.
multihead_attn
(
...
@@ -606,18 +587,18 @@ class TransformerDecoderLayer(nn.Layer):
...
@@ -606,18 +587,18 @@ class TransformerDecoderLayer(nn.Layer):
memory
,
memory
,
memory
,
memory
,
attn_mask
=
memory_mask
,
attn_mask
=
memory_mask
,
key_padding_mask
=
memory_key_padding_mask
)
[
0
]
key_padding_mask
=
memory_key_padding_mask
)
tgt
=
tgt
+
self
.
dropout2
(
tgt2
)
tgt
=
tgt
+
self
.
dropout2
(
tgt2
)
tgt
=
self
.
norm2
(
tgt
)
tgt
=
self
.
norm2
(
tgt
)
# default
# default
tgt
=
tgt
.
transpose
(
[
1
,
2
,
0
])
tgt
=
paddle
.
transpose
(
tgt
,
[
1
,
2
,
0
])
tgt
=
paddle
.
unsqueeze
(
tgt
,
2
)
tgt
=
paddle
.
unsqueeze
(
tgt
,
2
)
tgt2
=
self
.
conv2
(
F
.
relu
(
self
.
conv1
(
tgt
)))
tgt2
=
self
.
conv2
(
F
.
relu
(
self
.
conv1
(
tgt
)))
tgt2
=
paddle
.
squeeze
(
tgt2
,
2
)
tgt2
=
paddle
.
squeeze
(
tgt2
,
2
)
tgt2
=
tgt2
.
transpose
(
[
2
,
0
,
1
])
tgt2
=
paddle
.
transpose
(
tgt2
,
[
2
,
0
,
1
])
tgt
=
paddle
.
squeeze
(
tgt
,
2
)
tgt
=
paddle
.
squeeze
(
tgt
,
2
)
tgt
=
tgt
.
transpose
(
[
2
,
0
,
1
])
tgt
=
paddle
.
transpose
(
tgt
,
[
2
,
0
,
1
])
tgt
=
tgt
+
self
.
dropout3
(
tgt2
)
tgt
=
tgt
+
self
.
dropout3
(
tgt2
)
tgt
=
self
.
norm3
(
tgt
)
tgt
=
self
.
norm3
(
tgt
)
...
@@ -656,8 +637,8 @@ class PositionalEncoding(nn.Layer):
...
@@ -656,8 +637,8 @@ class PositionalEncoding(nn.Layer):
(
-
math
.
log
(
10000.0
)
/
dim
))
(
-
math
.
log
(
10000.0
)
/
dim
))
pe
[:,
0
::
2
]
=
paddle
.
sin
(
position
*
div_term
)
pe
[:,
0
::
2
]
=
paddle
.
sin
(
position
*
div_term
)
pe
[:,
1
::
2
]
=
paddle
.
cos
(
position
*
div_term
)
pe
[:,
1
::
2
]
=
paddle
.
cos
(
position
*
div_term
)
pe
=
p
e
.
unsqueeze
(
0
)
pe
=
p
addle
.
unsqueeze
(
pe
,
0
)
pe
=
p
e
.
transpose
(
[
1
,
0
,
2
])
pe
=
p
addle
.
transpose
(
pe
,
[
1
,
0
,
2
])
self
.
register_buffer
(
'pe'
,
pe
)
self
.
register_buffer
(
'pe'
,
pe
)
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
@@ -670,7 +651,7 @@ class PositionalEncoding(nn.Layer):
...
@@ -670,7 +651,7 @@ class PositionalEncoding(nn.Layer):
Examples:
Examples:
>>> output = pos_encoder(x)
>>> output = pos_encoder(x)
"""
"""
x
=
x
+
self
.
pe
[:
x
.
shape
[
0
],
:]
x
=
x
+
self
.
pe
[:
paddle
.
shape
(
x
)
[
0
],
:]
return
self
.
dropout
(
x
)
return
self
.
dropout
(
x
)
...
@@ -702,7 +683,7 @@ class PositionalEncoding_2d(nn.Layer):
...
@@ -702,7 +683,7 @@ class PositionalEncoding_2d(nn.Layer):
(
-
math
.
log
(
10000.0
)
/
dim
))
(
-
math
.
log
(
10000.0
)
/
dim
))
pe
[:,
0
::
2
]
=
paddle
.
sin
(
position
*
div_term
)
pe
[:,
0
::
2
]
=
paddle
.
sin
(
position
*
div_term
)
pe
[:,
1
::
2
]
=
paddle
.
cos
(
position
*
div_term
)
pe
[:,
1
::
2
]
=
paddle
.
cos
(
position
*
div_term
)
pe
=
p
e
.
unsqueeze
(
0
).
transpose
(
[
1
,
0
,
2
])
pe
=
p
addle
.
transpose
(
paddle
.
unsqueeze
(
pe
,
0
),
[
1
,
0
,
2
])
self
.
register_buffer
(
'pe'
,
pe
)
self
.
register_buffer
(
'pe'
,
pe
)
self
.
avg_pool_1
=
nn
.
AdaptiveAvgPool2D
((
1
,
1
))
self
.
avg_pool_1
=
nn
.
AdaptiveAvgPool2D
((
1
,
1
))
...
@@ -722,22 +703,23 @@ class PositionalEncoding_2d(nn.Layer):
...
@@ -722,22 +703,23 @@ class PositionalEncoding_2d(nn.Layer):
Examples:
Examples:
>>> output = pos_encoder(x)
>>> output = pos_encoder(x)
"""
"""
w_pe
=
self
.
pe
[:
x
.
shape
[
-
1
],
:]
w_pe
=
self
.
pe
[:
paddle
.
shape
(
x
)
[
-
1
],
:]
w1
=
self
.
linear1
(
self
.
avg_pool_1
(
x
).
squeeze
()).
unsqueeze
(
0
)
w1
=
self
.
linear1
(
self
.
avg_pool_1
(
x
).
squeeze
()).
unsqueeze
(
0
)
w_pe
=
w_pe
*
w1
w_pe
=
w_pe
*
w1
w_pe
=
w_pe
.
transpose
(
[
1
,
2
,
0
])
w_pe
=
paddle
.
transpose
(
w_pe
,
[
1
,
2
,
0
])
w_pe
=
w_pe
.
unsqueeze
(
2
)
w_pe
=
paddle
.
unsqueeze
(
w_pe
,
2
)
h_pe
=
self
.
pe
[:
x
.
shape
[
-
2
],
:]
h_pe
=
self
.
pe
[:
paddle
.
shape
(
x
)
.
shape
[
-
2
],
:]
w2
=
self
.
linear2
(
self
.
avg_pool_2
(
x
).
squeeze
()).
unsqueeze
(
0
)
w2
=
self
.
linear2
(
self
.
avg_pool_2
(
x
).
squeeze
()).
unsqueeze
(
0
)
h_pe
=
h_pe
*
w2
h_pe
=
h_pe
*
w2
h_pe
=
h_pe
.
transpose
(
[
1
,
2
,
0
])
h_pe
=
paddle
.
transpose
(
h_pe
,
[
1
,
2
,
0
])
h_pe
=
h_pe
.
unsqueeze
(
3
)
h_pe
=
paddle
.
unsqueeze
(
h_pe
,
3
)
x
=
x
+
w_pe
+
h_pe
x
=
x
+
w_pe
+
h_pe
x
=
x
.
reshape
(
x
=
paddle
.
transpose
(
[
x
.
shape
[
0
],
x
.
shape
[
1
],
x
.
shape
[
2
]
*
x
.
shape
[
3
]]).
transpose
(
paddle
.
reshape
(
x
,
[
2
,
0
,
1
])
[
x
.
shape
[
0
],
x
.
shape
[
1
],
x
.
shape
[
2
]
*
x
.
shape
[
3
]]),
[
2
,
0
,
1
])
return
self
.
dropout
(
x
)
return
self
.
dropout
(
x
)
...
@@ -817,7 +799,7 @@ class Beam():
...
@@ -817,7 +799,7 @@ class Beam():
def
sort_scores
(
self
):
def
sort_scores
(
self
):
"Sort the scores."
"Sort the scores."
return
self
.
scores
,
paddle
.
to_tensor
(
return
self
.
scores
,
paddle
.
to_tensor
(
[
i
for
i
in
range
(
self
.
scores
.
shape
[
0
]
)],
dtype
=
'int32'
)
[
i
for
i
in
range
(
int
(
self
.
scores
.
shape
[
0
])
)],
dtype
=
'int32'
)
def
get_the_best_score_and_idx
(
self
):
def
get_the_best_score_and_idx
(
self
):
"Get the score of the best in the beam."
"Get the score of the best in the beam."
...
...
ppocr/postprocess/rec_postprocess.py
浏览文件 @
30d908b6
...
@@ -176,7 +176,19 @@ class NRTRLabelDecode(BaseRecLabelDecode):
...
@@ -176,7 +176,19 @@ class NRTRLabelDecode(BaseRecLabelDecode):
else
:
else
:
preds_idx
=
preds
preds_idx
=
preds
text
=
self
.
decode
(
preds_idx
)
if
len
(
preds
)
==
2
:
preds_id
=
preds
[
0
]
preds_prob
=
preds
[
1
]
if
isinstance
(
preds_id
,
paddle
.
Tensor
):
preds_id
=
preds_id
.
numpy
()
if
isinstance
(
preds_prob
,
paddle
.
Tensor
):
preds_prob
=
preds_prob
.
numpy
()
if
preds_id
[
0
][
0
]
==
2
:
preds_idx
=
preds_id
[:,
1
:]
preds_prob
=
preds_prob
[:,
1
:]
else
:
preds_idx
=
preds_id
text
=
self
.
decode
(
preds_idx
,
preds_prob
,
is_remove_duplicate
=
False
)
if
label
is
None
:
if
label
is
None
:
return
text
return
text
label
=
self
.
decode
(
label
[:,
1
:])
label
=
self
.
decode
(
label
[:,
1
:])
...
...
tools/export_model.py
浏览文件 @
30d908b6
...
@@ -60,6 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
...
@@ -60,6 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
)
infer_shape
[
-
1
]
=
100
infer_shape
[
-
1
]
=
100
if
arch_config
[
"algorithm"
]
==
"NRTR"
:
infer_shape
=
[
1
,
32
,
100
]
elif
arch_config
[
"model_type"
]
==
"table"
:
elif
arch_config
[
"model_type"
]
==
"table"
:
infer_shape
=
[
3
,
488
,
488
]
infer_shape
=
[
3
,
488
,
488
]
model
=
to_static
(
model
=
to_static
(
...
...
tools/infer/predict_rec.py
浏览文件 @
30d908b6
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
import
sys
import
sys
from
PIL
import
Image
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
__dir__
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
__dir__
)
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
sys
.
path
.
append
(
os
.
path
.
abspath
(
os
.
path
.
join
(
__dir__
,
'../..'
)))
...
@@ -61,6 +61,13 @@ class TextRecognizer(object):
...
@@ -61,6 +61,13 @@ class TextRecognizer(object):
"character_dict_path"
:
args
.
rec_char_dict_path
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
"use_space_char"
:
args
.
use_space_char
"use_space_char"
:
args
.
use_space_char
}
}
elif
self
.
rec_algorithm
==
'NRTR'
:
postprocess_params
=
{
'name'
:
'NRTRLabelDecode'
,
"character_type"
:
args
.
rec_char_type
,
"character_dict_path"
:
args
.
rec_char_dict_path
,
"use_space_char"
:
args
.
use_space_char
}
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
postprocess_op
=
build_post_process
(
postprocess_params
)
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
\
self
.
predictor
,
self
.
input_tensor
,
self
.
output_tensors
,
self
.
config
=
\
utility
.
create_predictor
(
args
,
'rec'
,
logger
)
utility
.
create_predictor
(
args
,
'rec'
,
logger
)
...
@@ -87,6 +94,16 @@ class TextRecognizer(object):
...
@@ -87,6 +94,16 @@ class TextRecognizer(object):
def
resize_norm_img
(
self
,
img
,
max_wh_ratio
):
def
resize_norm_img
(
self
,
img
,
max_wh_ratio
):
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
imgC
,
imgH
,
imgW
=
self
.
rec_image_shape
if
self
.
rec_algorithm
==
'NRTR'
:
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2GRAY
)
# return padding_im
image_pil
=
Image
.
fromarray
(
np
.
uint8
(
img
))
img
=
image_pil
.
resize
([
100
,
32
],
Image
.
ANTIALIAS
)
img
=
np
.
array
(
img
)
norm_img
=
np
.
expand_dims
(
img
,
-
1
)
norm_img
=
norm_img
.
transpose
((
2
,
0
,
1
))
return
norm_img
.
astype
(
np
.
float32
)
/
128.
-
1.
assert
imgC
==
img
.
shape
[
2
]
assert
imgC
==
img
.
shape
[
2
]
max_wh_ratio
=
max
(
max_wh_ratio
,
imgW
/
imgH
)
max_wh_ratio
=
max
(
max_wh_ratio
,
imgW
/
imgH
)
imgW
=
int
((
32
*
max_wh_ratio
))
imgW
=
int
((
32
*
max_wh_ratio
))
...
@@ -252,14 +269,16 @@ class TextRecognizer(object):
...
@@ -252,14 +269,16 @@ class TextRecognizer(object):
else
:
else
:
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
input_tensor
.
copy_from_cpu
(
norm_img_batch
)
self
.
predictor
.
run
()
self
.
predictor
.
run
()
outputs
=
[]
outputs
=
[]
for
output_tensor
in
self
.
output_tensors
:
for
output_tensor
in
self
.
output_tensors
:
output
=
output_tensor
.
copy_to_cpu
()
output
=
output_tensor
.
copy_to_cpu
()
outputs
.
append
(
output
)
outputs
.
append
(
output
)
if
self
.
benchmark
:
if
self
.
benchmark
:
self
.
autolog
.
times
.
stamp
()
self
.
autolog
.
times
.
stamp
()
preds
=
outputs
[
0
]
if
len
(
outputs
)
!=
1
:
preds
=
outputs
else
:
preds
=
outputs
[
0
]
rec_result
=
self
.
postprocess_op
(
preds
)
rec_result
=
self
.
postprocess_op
(
preds
)
for
rno
in
range
(
len
(
rec_result
)):
for
rno
in
range
(
len
(
rec_result
)):
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
rec_result
[
rno
]
rec_res
[
indices
[
beg_img_no
+
rno
]]
=
rec_result
[
rno
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录