Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleOCR
提交
91f6f243
P
PaddleOCR
项目概览
s920243400
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
91f6f243
编写于
8月 27, 2020
作者:
T
tink2123
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
formate code
上级
5edb619c
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
9 addition
and
568 deletion
+9
-568
ppocr/modeling/heads/self_attention/model.py
ppocr/modeling/heads/self_attention/model.py
+9
-568
未找到文件。
ppocr/modeling/heads/self_attention/model.py
浏览文件 @
91f6f243
...
@@ -7,6 +7,11 @@ import paddle.fluid.layers as layers
...
@@ -7,6 +7,11 @@ import paddle.fluid.layers as layers
# Set seed for CE
# Set seed for CE
dropout_seed
=
None
dropout_seed
=
None
encoder_data_input_fields
=
(
"src_word"
,
"src_pos"
,
"src_slf_attn_bias"
,
)
def
wrap_layer_with_block
(
layer
,
block_idx
):
def
wrap_layer_with_block
(
layer
,
block_idx
):
"""
"""
...
@@ -45,25 +50,6 @@ def wrap_layer_with_block(layer, block_idx):
...
@@ -45,25 +50,6 @@ def wrap_layer_with_block(layer, block_idx):
return
layer_wrapper
return
layer_wrapper
def
position_encoding_init
(
n_position
,
d_pos_vec
):
"""
Generate the initial values for the sinusoid position encoding table.
"""
channels
=
d_pos_vec
position
=
np
.
arange
(
n_position
)
num_timescales
=
channels
//
2
log_timescale_increment
=
(
np
.
log
(
float
(
1e4
)
/
float
(
1
))
/
(
num_timescales
-
1
))
inv_timescales
=
np
.
exp
(
np
.
arange
(
num_timescales
))
*
-
log_timescale_increment
scaled_time
=
np
.
expand_dims
(
position
,
1
)
*
np
.
expand_dims
(
inv_timescales
,
0
)
signal
=
np
.
concatenate
([
np
.
sin
(
scaled_time
),
np
.
cos
(
scaled_time
)],
axis
=
1
)
signal
=
np
.
pad
(
signal
,
[[
0
,
0
],
[
0
,
np
.
mod
(
channels
,
2
)]],
'constant'
)
position_enc
=
signal
return
position_enc
.
astype
(
"float32"
)
def
multi_head_attention
(
queries
,
def
multi_head_attention
(
queries
,
keys
,
keys
,
values
,
values
,
...
@@ -271,7 +257,7 @@ post_process_layer = pre_post_process_layer
...
@@ -271,7 +257,7 @@ post_process_layer = pre_post_process_layer
def
prepare_encoder
(
def
prepare_encoder
(
src_word
,
#[b,t,c]
src_word
,
#
[b,t,c]
src_pos
,
src_pos
,
src_vocab_size
,
src_vocab_size
,
src_emb_dim
,
src_emb_dim
,
...
@@ -286,7 +272,7 @@ def prepare_encoder(
...
@@ -286,7 +272,7 @@ def prepare_encoder(
This module is used at the bottom of the encoder stacks.
This module is used at the bottom of the encoder stacks.
"""
"""
src_word_emb
=
src_word
#layers.concat(res,axis=1)
src_word_emb
=
src_word
#
layers.concat(res,axis=1)
src_word_emb
=
layers
.
cast
(
src_word_emb
,
'float32'
)
src_word_emb
=
layers
.
cast
(
src_word_emb
,
'float32'
)
# print("src_word_emb",src_word_emb)
# print("src_word_emb",src_word_emb)
...
@@ -338,12 +324,6 @@ def prepare_decoder(src_word,
...
@@ -338,12 +324,6 @@ def prepare_decoder(src_word,
is_test
=
False
)
if
dropout_rate
else
enc_input
is_test
=
False
)
if
dropout_rate
else
enc_input
# prepare_encoder = partial(
# prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[0])
# prepare_decoder = partial(
# prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[1])
def
encoder_layer
(
enc_input
,
def
encoder_layer
(
enc_input
,
attn_bias
,
attn_bias
,
n_head
,
n_head
,
...
@@ -412,234 +392,6 @@ def encoder(enc_input,
...
@@ -412,234 +392,6 @@ def encoder(enc_input,
return
enc_output
return
enc_output
def
decoder_layer
(
dec_input
,
enc_output
,
slf_attn_bias
,
dec_enc_attn_bias
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
cache
=
None
,
gather_idx
=
None
):
""" The layer to be stacked in decoder part.
The structure of this module is similar to that in the encoder part except
a multi-head attention is added to implement encoder-decoder attention.
"""
slf_attn_output
=
multi_head_attention
(
pre_process_layer
(
dec_input
,
preprocess_cmd
,
prepostprocess_dropout
),
None
,
None
,
slf_attn_bias
,
d_key
,
d_value
,
d_model
,
n_head
,
attention_dropout
,
cache
=
cache
,
gather_idx
=
gather_idx
)
slf_attn_output
=
post_process_layer
(
dec_input
,
slf_attn_output
,
postprocess_cmd
,
prepostprocess_dropout
,
)
enc_attn_output
=
multi_head_attention
(
pre_process_layer
(
slf_attn_output
,
preprocess_cmd
,
prepostprocess_dropout
),
enc_output
,
enc_output
,
dec_enc_attn_bias
,
d_key
,
d_value
,
d_model
,
n_head
,
attention_dropout
,
cache
=
cache
,
gather_idx
=
gather_idx
,
static_kv
=
True
)
enc_attn_output
=
post_process_layer
(
slf_attn_output
,
enc_attn_output
,
postprocess_cmd
,
prepostprocess_dropout
,
)
ffd_output
=
positionwise_feed_forward
(
pre_process_layer
(
enc_attn_output
,
preprocess_cmd
,
prepostprocess_dropout
),
d_inner_hid
,
d_model
,
relu_dropout
,
)
dec_output
=
post_process_layer
(
enc_attn_output
,
ffd_output
,
postprocess_cmd
,
prepostprocess_dropout
,
)
return
dec_output
def
decoder
(
dec_input
,
enc_output
,
dec_slf_attn_bias
,
dec_enc_attn_bias
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
caches
=
None
,
gather_idx
=
None
):
"""
The decoder is composed of a stack of identical decoder_layer layers.
"""
for
i
in
range
(
n_layer
):
dec_output
=
decoder_layer
(
dec_input
,
enc_output
,
dec_slf_attn_bias
,
dec_enc_attn_bias
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
cache
=
None
if
caches
is
None
else
caches
[
i
],
gather_idx
=
gather_idx
)
dec_input
=
dec_output
dec_output
=
pre_process_layer
(
dec_output
,
preprocess_cmd
,
prepostprocess_dropout
)
return
dec_output
def
make_all_inputs
(
input_fields
):
"""
Define the input data layers for the transformer model.
"""
inputs
=
[]
for
input_field
in
input_fields
:
input_var
=
layers
.
data
(
name
=
input_field
,
shape
=
input_descs
[
input_field
][
0
],
dtype
=
input_descs
[
input_field
][
1
],
lod_level
=
input_descs
[
input_field
][
2
]
if
len
(
input_descs
[
input_field
])
==
3
else
0
,
append_batch_size
=
False
)
inputs
.
append
(
input_var
)
return
inputs
def
make_all_py_reader_inputs
(
input_fields
,
is_test
=
False
):
reader
=
layers
.
py_reader
(
capacity
=
20
,
name
=
"test_reader"
if
is_test
else
"train_reader"
,
shapes
=
[
input_descs
[
input_field
][
0
]
for
input_field
in
input_fields
],
dtypes
=
[
input_descs
[
input_field
][
1
]
for
input_field
in
input_fields
],
lod_levels
=
[
input_descs
[
input_field
][
2
]
if
len
(
input_descs
[
input_field
])
==
3
else
0
for
input_field
in
input_fields
])
return
layers
.
read_file
(
reader
),
reader
def
transformer
(
src_vocab_size
,
trg_vocab_size
,
max_length
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
label_smooth_eps
,
bos_idx
=
0
,
use_py_reader
=
False
,
is_test
=
False
):
if
weight_sharing
:
assert
src_vocab_size
==
trg_vocab_size
,
(
"Vocabularies in source and target should be same for weight sharing."
)
data_input_names
=
encoder_data_input_fields
+
\
decoder_data_input_fields
[:
-
1
]
+
label_data_input_fields
if
use_py_reader
:
all_inputs
,
reader
=
make_all_py_reader_inputs
(
data_input_names
,
is_test
)
else
:
all_inputs
=
make_all_inputs
(
data_input_names
)
# print("all inputs",all_inputs)
enc_inputs_len
=
len
(
encoder_data_input_fields
)
dec_inputs_len
=
len
(
decoder_data_input_fields
[:
-
1
])
enc_inputs
=
all_inputs
[
0
:
enc_inputs_len
]
dec_inputs
=
all_inputs
[
enc_inputs_len
:
enc_inputs_len
+
dec_inputs_len
]
label
=
all_inputs
[
-
2
]
weights
=
all_inputs
[
-
1
]
enc_output
=
wrap_encoder
(
src_vocab_size
,
64
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
enc_inputs
)
predict
=
wrap_decoder
(
trg_vocab_size
,
max_length
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
dec_inputs
,
enc_output
,
)
# Padding index do not contribute to the total loss. The weights is used to
# cancel padding index in calculating the loss.
if
label_smooth_eps
:
label
=
layers
.
label_smooth
(
label
=
layers
.
one_hot
(
input
=
label
,
depth
=
trg_vocab_size
),
epsilon
=
label_smooth_eps
)
cost
=
layers
.
softmax_with_cross_entropy
(
logits
=
predict
,
label
=
label
,
soft_label
=
True
if
label_smooth_eps
else
False
)
weighted_cost
=
cost
*
weights
sum_cost
=
layers
.
reduce_sum
(
weighted_cost
)
token_num
=
layers
.
reduce_sum
(
weights
)
token_num
.
stop_gradient
=
True
avg_cost
=
sum_cost
/
token_num
return
sum_cost
,
avg_cost
,
predict
,
token_num
,
reader
if
use_py_reader
else
None
def
wrap_encoder_forFeature
(
src_vocab_size
,
def
wrap_encoder_forFeature
(
src_vocab_size
,
max_length
,
max_length
,
n_layer
,
n_layer
,
...
@@ -662,44 +414,8 @@ def wrap_encoder_forFeature(src_vocab_size,
...
@@ -662,44 +414,8 @@ def wrap_encoder_forFeature(src_vocab_size,
img
img
"""
"""
if
enc_inputs
is
None
:
# This is used to implement independent encoder program in inference.
conv_features
,
src_pos
,
src_slf_attn_bias
=
make_all_inputs
(
encoder_data_input_fields
)
else
:
conv_features
,
src_pos
,
src_slf_attn_bias
=
enc_inputs
#
conv_features
,
src_pos
,
src_slf_attn_bias
=
enc_inputs
#
b
,
t
,
c
=
conv_features
.
shape
b
,
t
,
c
=
conv_features
.
shape
#"""
# insert cnn
#"""
#import basemodel
# feat = basemodel.resnet_50(img)
# mycrnn = basemodel.CRNN()
# feat = mycrnn.ocr_convs(img,use_cudnn=TrainTaskConfig.use_gpu)
# b, c, w, h = feat.shape
# src_word = layers.reshape(feat, shape=[-1, c, w * h])
#myconv8 = basemodel.conv8()
#feat = myconv8.net(img )
#b , c, h, w = feat.shape#h=6
#print(feat)
#layers.Print(feat,message="conv_feat",summarize=10)
#feat =layers.conv2d(feat,c,filter_size =[4 , 1],act="relu")
#feat = layers.pool2d(feat,pool_stride=(3,1),pool_size=(3,1))
#src_word = layers.squeeze(feat,axes=[2]) #src_word [-1,c,ww]
#feat = layers.transpose(feat, [0,3,1,2])
#src_word = layers.reshape(feat,[-1,w, c*h])
#src_word = layers.im2sequence(
# input=feat,
# stride=[1, 1],
# filter_size=[feat.shape[2], 1])
#layers.Print(src_word,message="src_word",summarize=10)
# print('feat',feat)
#print("src_word",src_word)
enc_input
=
prepare_encoder
(
enc_input
=
prepare_encoder
(
conv_features
,
conv_features
,
...
@@ -755,37 +471,7 @@ def wrap_encoder(src_vocab_size,
...
@@ -755,37 +471,7 @@ def wrap_encoder(src_vocab_size,
encoder_data_input_fields
)
encoder_data_input_fields
)
else
:
else
:
src_word
,
src_pos
,
src_slf_attn_bias
=
enc_inputs
#
src_word
,
src_pos
,
src_slf_attn_bias
=
enc_inputs
#
#"""
# insert cnn
#"""
#import basemodel
# feat = basemodel.resnet_50(img)
# mycrnn = basemodel.CRNN()
# feat = mycrnn.ocr_convs(img,use_cudnn=TrainTaskConfig.use_gpu)
# b, c, w, h = feat.shape
# src_word = layers.reshape(feat, shape=[-1, c, w * h])
#myconv8 = basemodel.conv8()
#feat = myconv8.net(img )
#b , c, h, w = feat.shape#h=6
#print(feat)
#layers.Print(feat,message="conv_feat",summarize=10)
#feat =layers.conv2d(feat,c,filter_size =[4 , 1],act="relu")
#feat = layers.pool2d(feat,pool_stride=(3,1),pool_size=(3,1))
#src_word = layers.squeeze(feat,axes=[2]) #src_word [-1,c,ww]
#feat = layers.transpose(feat, [0,3,1,2])
#src_word = layers.reshape(feat,[-1,w, c*h])
#src_word = layers.im2sequence(
# input=feat,
# stride=[1, 1],
# filter_size=[feat.shape[2], 1])
#layers.Print(src_word,message="src_word",summarize=10)
# print('feat',feat)
#print("src_word",src_word)
enc_input
=
prepare_decoder
(
enc_input
=
prepare_decoder
(
src_word
,
src_word
,
src_pos
,
src_pos
,
...
@@ -811,248 +497,3 @@ def wrap_encoder(src_vocab_size,
...
@@ -811,248 +497,3 @@ def wrap_encoder(src_vocab_size,
preprocess_cmd
,
preprocess_cmd
,
postprocess_cmd
,
)
postprocess_cmd
,
)
return
enc_output
return
enc_output
def
wrap_decoder
(
trg_vocab_size
,
max_length
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
dec_inputs
=
None
,
enc_output
=
None
,
caches
=
None
,
gather_idx
=
None
,
bos_idx
=
0
):
"""
The wrapper assembles together all needed layers for the decoder.
"""
if
dec_inputs
is
None
:
# This is used to implement independent decoder program in inference.
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
enc_output
=
\
make_all_inputs
(
decoder_data_input_fields
)
else
:
trg_word
,
trg_pos
,
trg_slf_attn_bias
,
trg_src_attn_bias
=
dec_inputs
dec_input
=
prepare_decoder
(
trg_word
,
trg_pos
,
trg_vocab_size
,
d_model
,
max_length
,
prepostprocess_dropout
,
bos_idx
=
bos_idx
,
word_emb_param_name
=
"src_word_emb_table"
if
weight_sharing
else
"trg_word_emb_table"
)
dec_output
=
decoder
(
dec_input
,
enc_output
,
trg_slf_attn_bias
,
trg_src_attn_bias
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
caches
=
caches
,
gather_idx
=
gather_idx
)
return
dec_output
# Reshape to 2D tensor to use GEMM instead of BatchedGEMM
dec_output
=
layers
.
reshape
(
dec_output
,
shape
=
[
-
1
,
dec_output
.
shape
[
-
1
]],
inplace
=
True
)
if
weight_sharing
:
predict
=
layers
.
matmul
(
x
=
dec_output
,
y
=
fluid
.
default_main_program
().
global_block
().
var
(
"trg_word_emb_table"
),
transpose_y
=
True
)
else
:
predict
=
layers
.
fc
(
input
=
dec_output
,
size
=
trg_vocab_size
,
bias_attr
=
False
)
if
dec_inputs
is
None
:
# Return probs for independent decoder program.
predict
=
layers
.
softmax
(
predict
)
return
predict
def
fast_decode
(
src_vocab_size
,
trg_vocab_size
,
max_in_len
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
beam_size
,
max_out_len
,
bos_idx
,
eos_idx
,
use_py_reader
=
False
):
"""
Use beam search to decode. Caches will be used to store states of history
steps which can make the decoding faster.
"""
data_input_names
=
encoder_data_input_fields
+
fast_decoder_data_input_fields
if
use_py_reader
:
all_inputs
,
reader
=
make_all_py_reader_inputs
(
data_input_names
)
else
:
all_inputs
=
make_all_inputs
(
data_input_names
)
enc_inputs_len
=
len
(
encoder_data_input_fields
)
dec_inputs_len
=
len
(
fast_decoder_data_input_fields
)
enc_inputs
=
all_inputs
[
0
:
enc_inputs_len
]
#enc_inputs tensor
dec_inputs
=
all_inputs
[
enc_inputs_len
:
enc_inputs_len
+
dec_inputs_len
]
#dec_inputs tensor
enc_output
=
wrap_encoder
(
src_vocab_size
,
64
,
##to do !!!!!????
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
enc_inputs
,
bos_idx
=
bos_idx
)
start_tokens
,
init_scores
,
parent_idx
,
trg_src_attn_bias
=
dec_inputs
def
beam_search
():
max_len
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
start_tokens
.
dtype
,
value
=
max_out_len
,
force_cpu
=
True
)
step_idx
=
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
start_tokens
.
dtype
,
value
=
0
,
force_cpu
=
True
)
cond
=
layers
.
less_than
(
x
=
step_idx
,
y
=
max_len
)
# default force_cpu=True
while_op
=
layers
.
While
(
cond
)
# array states will be stored for each step.
ids
=
layers
.
array_write
(
layers
.
reshape
(
start_tokens
,
(
-
1
,
1
)),
step_idx
)
scores
=
layers
.
array_write
(
init_scores
,
step_idx
)
# cell states will be overwrited at each step.
# caches contains states of history steps in decoder self-attention
# and static encoder output projections in encoder-decoder attention
# to reduce redundant computation.
caches
=
[
{
"k"
:
# for self attention
layers
.
fill_constant_batch_size_like
(
input
=
start_tokens
,
shape
=
[
-
1
,
n_head
,
0
,
d_key
],
dtype
=
enc_output
.
dtype
,
value
=
0
),
"v"
:
# for self attention
layers
.
fill_constant_batch_size_like
(
input
=
start_tokens
,
shape
=
[
-
1
,
n_head
,
0
,
d_value
],
dtype
=
enc_output
.
dtype
,
value
=
0
),
"static_k"
:
# for encoder-decoder attention
layers
.
create_tensor
(
dtype
=
enc_output
.
dtype
),
"static_v"
:
# for encoder-decoder attention
layers
.
create_tensor
(
dtype
=
enc_output
.
dtype
)
}
for
i
in
range
(
n_layer
)
]
with
while_op
.
block
():
pre_ids
=
layers
.
array_read
(
array
=
ids
,
i
=
step_idx
)
# Since beam_search_op dosen't enforce pre_ids' shape, we can do
# inplace reshape here which actually change the shape of pre_ids.
pre_ids
=
layers
.
reshape
(
pre_ids
,
(
-
1
,
1
,
1
),
inplace
=
True
)
pre_scores
=
layers
.
array_read
(
array
=
scores
,
i
=
step_idx
)
# gather cell states corresponding to selected parent
pre_src_attn_bias
=
layers
.
gather
(
trg_src_attn_bias
,
index
=
parent_idx
)
pre_pos
=
layers
.
elementwise_mul
(
x
=
layers
.
fill_constant_batch_size_like
(
input
=
pre_src_attn_bias
,
# cann't use lod tensor here
value
=
1
,
shape
=
[
-
1
,
1
,
1
],
dtype
=
pre_ids
.
dtype
),
y
=
step_idx
,
axis
=
0
)
logits
=
wrap_decoder
(
trg_vocab_size
,
max_in_len
,
n_layer
,
n_head
,
d_key
,
d_value
,
d_model
,
d_inner_hid
,
prepostprocess_dropout
,
attention_dropout
,
relu_dropout
,
preprocess_cmd
,
postprocess_cmd
,
weight_sharing
,
dec_inputs
=
(
pre_ids
,
pre_pos
,
None
,
pre_src_attn_bias
),
enc_output
=
enc_output
,
caches
=
caches
,
gather_idx
=
parent_idx
,
bos_idx
=
bos_idx
)
# intra-beam topK
topk_scores
,
topk_indices
=
layers
.
topk
(
input
=
layers
.
softmax
(
logits
),
k
=
beam_size
)
accu_scores
=
layers
.
elementwise_add
(
x
=
layers
.
log
(
topk_scores
),
y
=
pre_scores
,
axis
=
0
)
# beam_search op uses lod to differentiate branches.
accu_scores
=
layers
.
lod_reset
(
accu_scores
,
pre_ids
)
# topK reduction across beams, also contain special handle of
# end beams and end sentences(batch reduction)
selected_ids
,
selected_scores
,
gather_idx
=
layers
.
beam_search
(
pre_ids
=
pre_ids
,
pre_scores
=
pre_scores
,
ids
=
topk_indices
,
scores
=
accu_scores
,
beam_size
=
beam_size
,
end_id
=
eos_idx
,
return_parent_idx
=
True
)
layers
.
increment
(
x
=
step_idx
,
value
=
1.0
,
in_place
=
True
)
# cell states(caches) have been updated in wrap_decoder,
# only need to update beam search states here.
layers
.
array_write
(
selected_ids
,
i
=
step_idx
,
array
=
ids
)
layers
.
array_write
(
selected_scores
,
i
=
step_idx
,
array
=
scores
)
layers
.
assign
(
gather_idx
,
parent_idx
)
layers
.
assign
(
pre_src_attn_bias
,
trg_src_attn_bias
)
length_cond
=
layers
.
less_than
(
x
=
step_idx
,
y
=
max_len
)
finish_cond
=
layers
.
logical_not
(
layers
.
is_empty
(
x
=
selected_ids
))
layers
.
logical_and
(
x
=
length_cond
,
y
=
finish_cond
,
out
=
cond
)
finished_ids
,
finished_scores
=
layers
.
beam_search_decode
(
ids
,
scores
,
beam_size
=
beam_size
,
end_id
=
eos_idx
)
return
finished_ids
,
finished_scores
finished_ids
,
finished_scores
=
beam_search
()
return
finished_ids
,
finished_scores
,
reader
if
use_py_reader
else
None
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录