PaddlePaddle / models
Commit dad22cfa (unverified)
add_nets (#3416)
Authored on Sep 25, 2019 by 0YuanZhang0; committed via GitHub on Sep 25, 2019.
Parent: e7b1fef2
Showing 2 changed files with 584 additions and 0 deletions (+584, -0).
PaddleNLP/Research/MRQA2019-D-NET/server/bert_server/pdnlp/nets/bert.py  (+231, -0)
PaddleNLP/Research/MRQA2019-D-NET/server/bert_server/pdnlp/nets/transformer_encoder.py  (+353, -0)
PaddleNLP/Research/MRQA2019-D-NET/server/bert_server/pdnlp/nets/bert.py
new file mode 100644
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import json
import numpy as np
import paddle.fluid as fluid

from palm.nets.transformer_encoder import encoder as encoder
from palm.nets.transformer_encoder import pre_process_layer as pre_process_layer


class BertModel(object):
    def __init__(self,
                 src_ids,
                 position_ids,
                 sentence_ids,
                 input_mask,
                 config,
                 weight_sharing=True,
                 use_fp16=False,
                 model_name=''):

        self._emb_size = config["hidden_size"]
        self._n_layer = config["num_hidden_layers"]
        self._n_head = config["num_attention_heads"]
        self._voc_size = config["vocab_size"]
        self._max_position_seq_len = config["max_position_embeddings"]
        self._sent_types = config["type_vocab_size"]
        self._hidden_act = config["hidden_act"]
        self._prepostprocess_dropout = config["hidden_dropout_prob"]
        self._attention_dropout = config["attention_probs_dropout_prob"]
        self._weight_sharing = weight_sharing

        self.model_name = model_name

        self._word_emb_name = self.model_name + "word_embedding"
        self._pos_emb_name = self.model_name + "pos_embedding"
        self._sent_emb_name = self.model_name + "sent_embedding"
        self._dtype = "float16" if use_fp16 else "float32"

        # Initialize all weights with the truncated normal initializer; all
        # biases are initialized to constant zero by default.
        self._param_initializer = fluid.initializer.TruncatedNormal(
            scale=config["initializer_range"])

        self._build_model(src_ids, position_ids, sentence_ids, input_mask,
                          config)

    def _build_model(self, src_ids, position_ids, sentence_ids, input_mask,
                     config):
        # padding id in vocabulary must be set to 0
        emb_out = fluid.layers.embedding(
            input=src_ids,
            size=[self._voc_size, self._emb_size],
            dtype=self._dtype,
            param_attr=fluid.ParamAttr(
                name=self._word_emb_name, initializer=self._param_initializer),
            is_sparse=False)

        self.emb_out = emb_out

        position_emb_out = fluid.layers.embedding(
            input=position_ids,
            size=[self._max_position_seq_len, self._emb_size],
            dtype=self._dtype,
            param_attr=fluid.ParamAttr(
                name=self._pos_emb_name, initializer=self._param_initializer))

        self.position_emb_out = position_emb_out

        sent_emb_out = fluid.layers.embedding(
            sentence_ids,
            size=[self._sent_types, self._emb_size],
            dtype=self._dtype,
            param_attr=fluid.ParamAttr(
                name=self._sent_emb_name, initializer=self._param_initializer))

        self.sent_emb_out = sent_emb_out

        emb_out = emb_out + position_emb_out
        emb_out = emb_out + sent_emb_out

        emb_out = pre_process_layer(
            emb_out, 'nd', self._prepostprocess_dropout, name='pre_encoder')

        if self._dtype == "float16":
            input_mask = fluid.layers.cast(x=input_mask, dtype=self._dtype)

        self_attn_mask = fluid.layers.matmul(
            x=input_mask, y=input_mask, transpose_y=True)

        self_attn_mask = fluid.layers.scale(
            x=self_attn_mask,
            scale=config["self_att_scale"],
            bias=-1.0,
            bias_after_scale=False)

        n_head_self_attn_mask = fluid.layers.stack(
            x=[self_attn_mask] * self._n_head, axis=1)

        n_head_self_attn_mask.stop_gradient = True

        self._enc_out = encoder(
            enc_input=emb_out,
            attn_bias=n_head_self_attn_mask,
            n_layer=self._n_layer,
            n_head=self._n_head,
            d_key=self._emb_size // self._n_head,
            d_value=self._emb_size // self._n_head,
            d_model=self._emb_size,
            d_inner_hid=self._emb_size * 4,
            prepostprocess_dropout=self._prepostprocess_dropout,
            attention_dropout=self._attention_dropout,
            relu_dropout=0,
            hidden_act=self._hidden_act,
            preprocess_cmd="",
            postprocess_cmd="dan",
            param_initializer=self._param_initializer,
            name=self.model_name + 'encoder')

    def get_sequence_output(self):
        return self._enc_out

    def get_pooled_output(self):
        """Get the first feature of each sequence for classification."""
        next_sent_feat = fluid.layers.slice(
            input=self._enc_out, axes=[1], starts=[0], ends=[1])
        next_sent_feat = fluid.layers.fc(
            input=next_sent_feat,
            size=self._emb_size,
            act="tanh",
            param_attr=fluid.ParamAttr(
                name=self.model_name + "pooled_fc.w_0",
                initializer=self._param_initializer),
            bias_attr="pooled_fc.b_0")
        return next_sent_feat

    def get_pretraining_output(self, mask_label, mask_pos, labels):
        """Get the loss & accuracy for pretraining."""
        mask_pos = fluid.layers.cast(x=mask_pos, dtype='int32')

        # extract the first token feature in each sentence
        next_sent_feat = self.get_pooled_output()
        reshaped_emb_out = fluid.layers.reshape(
            x=self._enc_out, shape=[-1, self._emb_size])
        # extract masked tokens' feature
        mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)

        # transform: fc
        mask_trans_feat = fluid.layers.fc(
            input=mask_feat,
            size=self._emb_size,
            act=self._hidden_act,
            param_attr=fluid.ParamAttr(
                name=self.model_name + 'mask_lm_trans_fc.w_0',
                initializer=self._param_initializer),
            bias_attr=fluid.ParamAttr(
                name=self.model_name + 'mask_lm_trans_fc.b_0'))
        # transform: layer norm
        mask_trans_feat = pre_process_layer(
            mask_trans_feat, 'n', name=self.model_name + 'mask_lm_trans')

        mask_lm_out_bias_attr = fluid.ParamAttr(
            name=self.model_name + "mask_lm_out_fc.b_0",
            initializer=fluid.initializer.Constant(value=0.0))
        if self._weight_sharing:
            fc_out = fluid.layers.matmul(
                x=mask_trans_feat,
                y=fluid.default_main_program().global_block().var(
                    self._word_emb_name),
                transpose_y=True)
            fc_out += fluid.layers.create_parameter(
                shape=[self._voc_size],
                dtype=self._dtype,
                attr=mask_lm_out_bias_attr,
                is_bias=True)
        else:
            fc_out = fluid.layers.fc(
                input=mask_trans_feat,
                size=self._voc_size,
                param_attr=fluid.ParamAttr(
                    name=self.model_name + "mask_lm_out_fc.w_0",
                    initializer=self._param_initializer),
                bias_attr=mask_lm_out_bias_attr)

        mask_lm_loss = fluid.layers.softmax_with_cross_entropy(
            logits=fc_out, label=mask_label)
        mean_mask_lm_loss = fluid.layers.mean(mask_lm_loss)

        next_sent_fc_out = fluid.layers.fc(
            input=next_sent_feat,
            size=2,
            param_attr=fluid.ParamAttr(
                name=self.model_name + "next_sent_fc.w_0",
                initializer=self._param_initializer),
            bias_attr=self.model_name + "next_sent_fc.b_0")

        next_sent_loss, next_sent_softmax = fluid.layers.softmax_with_cross_entropy(
            logits=next_sent_fc_out, label=labels, return_softmax=True)

        next_sent_acc = fluid.layers.accuracy(
            input=next_sent_softmax, label=labels)

        mean_next_sent_loss = fluid.layers.mean(next_sent_loss)

        loss = mean_next_sent_loss + mean_mask_lm_loss
        return next_sent_acc, mean_mask_lm_loss, loss


if __name__ == "__main__":
    print("hello world!")
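A minimal usage sketch for BertModel under the Paddle Fluid 1.x static-graph API. Everything below is illustrative and not part of this commit: the config values mirror a BERT-base setup, the data-layer names and sequence length are arbitrary, and "self_att_scale" is a placeholder; the MRQA2019-D-NET server loads its real config from a JSON file.

import paddle.fluid as fluid

# Hypothetical BERT-base style config; the real values come from the
# server's JSON config, and "self_att_scale" in particular may differ.
bert_config = {
    "hidden_size": 768,
    "num_hidden_layers": 12,
    "num_attention_heads": 12,
    "vocab_size": 30522,
    "max_position_embeddings": 512,
    "type_vocab_size": 2,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "attention_probs_dropout_prob": 0.1,
    "initializer_range": 0.02,
    "self_att_scale": 1.0,
}

seq_len = 128
# Static-graph placeholders; Fluid prepends the batch dimension, giving
# [batch, seq_len, 1] tensors as expected by BertModel.
src_ids = fluid.layers.data(name='src_ids', shape=[seq_len, 1], dtype='int64')
position_ids = fluid.layers.data(name='pos_ids', shape=[seq_len, 1], dtype='int64')
sentence_ids = fluid.layers.data(name='sent_ids', shape=[seq_len, 1], dtype='int64')
input_mask = fluid.layers.data(name='input_mask', shape=[seq_len, 1], dtype='float32')

bert = BertModel(
    src_ids=src_ids,
    position_ids=position_ids,
    sentence_ids=sentence_ids,
    input_mask=input_mask,
    config=bert_config,
    weight_sharing=True,
    use_fp16=False,
    model_name='')

seq_out = bert.get_sequence_output()  # [batch, seq_len, hidden_size]
pooled = bert.get_pooled_output()     # first-token feature through a tanh FC, [batch, hidden_size]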
PaddleNLP/Research/MRQA2019-D-NET/server/bert_server/pdnlp/nets/transformer_encoder.py
new file mode 100644
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer encoder."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from functools import partial

import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers


def multi_head_attention(queries,
                         keys,
                         values,
                         attn_bias,
                         d_key,
                         d_value,
                         d_model,
                         n_head=1,
                         dropout_rate=0.,
                         cache=None,
                         param_initializer=None,
                         name='multi_head_att'):
    """
    Multi-Head Attention. Note that attn_bias is added to the logit before
    computing the softmax activation, to mask certain selected positions so
    that they will not be considered in the attention weights.
    """
    keys = queries if keys is None else keys
    values = keys if values is None else values

    if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
        raise ValueError(
            "Inputs: queries, keys and values should all be 3-D tensors.")

    def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      num_flatten_dims=2,
                      param_attr=fluid.ParamAttr(
                          name=name + '_query_fc.w_0',
                          initializer=param_initializer),
                      bias_attr=name + '_query_fc.b_0')
        k = layers.fc(input=keys,
                      size=d_key * n_head,
                      num_flatten_dims=2,
                      param_attr=fluid.ParamAttr(
                          name=name + '_key_fc.w_0',
                          initializer=param_initializer),
                      bias_attr=name + '_key_fc.b_0')
        v = layers.fc(input=values,
                      size=d_value * n_head,
                      num_flatten_dims=2,
                      param_attr=fluid.ParamAttr(
                          name=name + '_value_fc.w_0',
                          initializer=param_initializer),
                      bias_attr=name + '_value_fc.b_0')
        return q, k, v

    def __split_heads(x, n_head):
        """
        Reshape the last dimension of input tensor x so that it becomes two
        dimensions and then transpose. Specifically, input a tensor with shape
        [bs, max_sequence_length, n_head * hidden_dim] then output a tensor
        with shape [bs, n_head, max_sequence_length, hidden_dim].
        """
        hidden_size = x.shape[-1]
        # The value 0 in the shape attr means copying the corresponding
        # dimension size of the input as the output dimension size.
        reshaped = layers.reshape(
            x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)

        # permute the dimensions into:
        # [batch_size, n_head, max_sequence_len, hidden_size_per_head]
        return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])

    def __combine_heads(x):
        """
        Transpose and then reshape the last two dimensions of input tensor x
        so that it becomes one dimension, which is the reverse of __split_heads.
        """
        if len(x.shape) == 3:
            return x
        if len(x.shape) != 4:
            raise ValueError("Input(x) should be a 4-D Tensor.")

        trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
        # The value 0 in the shape attr means copying the corresponding
        # dimension size of the input as the output dimension size.
        return layers.reshape(
            x=trans_x,
            shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
            inplace=True)

    def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
        """
        Scaled Dot-Product Attention
        """
        scaled_q = layers.scale(x=q, scale=d_key**-0.5)
        product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
        if attn_bias:
            product += attn_bias
        weights = layers.softmax(product)
        if dropout_rate:
            weights = layers.dropout(
                weights,
                dropout_prob=dropout_rate,
                dropout_implementation="upscale_in_train",
                is_test=False)
        out = layers.matmul(weights, v)
        return out

    q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)

    if cache is not None:  # use cache and concat time steps
        # Since the inplace reshape in __split_heads changes the shape of k and
        # v, which is the cache input for the next time step, reshape the cache
        # input from the previous time step first.
        k = cache["k"] = layers.concat(
            [layers.reshape(
                cache["k"], shape=[0, 0, d_model]), k], axis=1)
        v = cache["v"] = layers.concat(
            [layers.reshape(
                cache["v"], shape=[0, 0, d_model]), v], axis=1)

    q = __split_heads(q, n_head)
    k = __split_heads(k, n_head)
    v = __split_heads(v, n_head)

    ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
                                                  dropout_rate)

    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         num_flatten_dims=2,
                         param_attr=fluid.ParamAttr(
                             name=name + '_output_fc.w_0',
                             initializer=param_initializer),
                         bias_attr=name + '_output_fc.b_0')
    return proj_out


def positionwise_feed_forward(x,
                              d_inner_hid,
                              d_hid,
                              dropout_rate,
                              hidden_act,
                              param_initializer=None,
                              name='ffn'):
    """
    Position-wise Feed-Forward Networks.
    This module consists of two linear transformations with a ReLU activation
    in between, which is applied to each position separately and identically.
    """
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=2,
                       act=hidden_act,
                       param_attr=fluid.ParamAttr(
                           name=name + '_fc_0.w_0',
                           initializer=param_initializer),
                       bias_attr=name + '_fc_0.b_0')
    if dropout_rate:
        hidden = layers.dropout(
            hidden,
            dropout_prob=dropout_rate,
            dropout_implementation="upscale_in_train",
            is_test=False)
    out = layers.fc(input=hidden,
                    size=d_hid,
                    num_flatten_dims=2,
                    param_attr=fluid.ParamAttr(
                        name=name + '_fc_1.w_0',
                        initializer=param_initializer),
                    bias_attr=name + '_fc_1.b_0')
    return out


def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.,
                           name=''):
    """
    Add residual connection, layer normalization and dropout to the out tensor
    optionally according to the value of process_cmd.
    This will be used before or after multi-head attention and position-wise
    feed-forward networks.
    """
    for cmd in process_cmd:
        if cmd == "a":  # add residual connection
            out = out + prev_out if prev_out else out
        elif cmd == "n":  # add layer normalization
            out_dtype = out.dtype
            if out_dtype == fluid.core.VarDesc.VarType.FP16:
                out = layers.cast(x=out, dtype="float32")
            out = layers.layer_norm(
                out,
                begin_norm_axis=len(out.shape) - 1,
                param_attr=fluid.ParamAttr(
                    name=name + '_layer_norm_scale',
                    initializer=fluid.initializer.Constant(1.)),
                bias_attr=fluid.ParamAttr(
                    name=name + '_layer_norm_bias',
                    initializer=fluid.initializer.Constant(0.)))
            if out_dtype == fluid.core.VarDesc.VarType.FP16:
                out = layers.cast(x=out, dtype="float16")
        elif cmd == "d":  # add dropout
            if dropout_rate:
                out = layers.dropout(
                    out,
                    dropout_prob=dropout_rate,
                    dropout_implementation="upscale_in_train",
                    is_test=False)
    return out


pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer


def encoder_layer(enc_input,
                  attn_bias,
                  n_head,
                  d_key,
                  d_value,
                  d_model,
                  d_inner_hid,
                  prepostprocess_dropout,
                  attention_dropout,
                  relu_dropout,
                  hidden_act,
                  preprocess_cmd="n",
                  postprocess_cmd="da",
                  param_initializer=None,
                  name=''):
    """
    The encoder layers that can be stacked to form a deep encoder.
    This module consists of a multi-head (self) attention followed by
    position-wise feed-forward networks, and both components are accompanied
    by post_process_layer to add residual connection, layer normalization
    and dropout.
    """
    attn_output = multi_head_attention(
        pre_process_layer(
            enc_input,
            preprocess_cmd,
            prepostprocess_dropout,
            name=name + '_pre_att'),
        None,
        None,
        attn_bias,
        d_key,
        d_value,
        d_model,
        n_head,
        attention_dropout,
        param_initializer=param_initializer,
        name=name + '_multi_head_att')
    attn_output = post_process_layer(
        enc_input,
        attn_output,
        postprocess_cmd,
        prepostprocess_dropout,
        name=name + '_post_att')
    ffd_output = positionwise_feed_forward(
        pre_process_layer(
            attn_output,
            preprocess_cmd,
            prepostprocess_dropout,
            name=name + '_pre_ffn'),
        d_inner_hid,
        d_model,
        relu_dropout,
        hidden_act,
        param_initializer=param_initializer,
        name=name + '_ffn')
    return post_process_layer(
        attn_output,
        ffd_output,
        postprocess_cmd,
        prepostprocess_dropout,
        name=name + '_post_ffn')


def encoder(enc_input,
            attn_bias,
            n_layer,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            hidden_act,
            preprocess_cmd="n",
            postprocess_cmd="da",
            param_initializer=None,
            name='',
            return_all=False):
    """
    The encoder is composed of a stack of identical layers returned by calling
    encoder_layer.
    """
    enc_outputs = []
    for i in range(n_layer):
        enc_output = encoder_layer(
            enc_input,
            attn_bias,
            n_head,
            d_key,
            d_value,
            d_model,
            d_inner_hid,
            prepostprocess_dropout,
            attention_dropout,
            relu_dropout,
            hidden_act,
            preprocess_cmd,
            postprocess_cmd,
            param_initializer=param_initializer,
            name=name + '_layer_' + str(i))
        enc_input = enc_output
        if i < n_layer - 1:
            enc_outputs.append(enc_output)

    enc_output = pre_process_layer(
        enc_output, preprocess_cmd, prepostprocess_dropout,
        name="post_encoder")
    enc_outputs.append(enc_output)

    if not return_all:
        return enc_output
    else:
        return enc_output, enc_outputs
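The process_cmd strings consumed by pre_post_process_layer are interpreted one letter at a time: 'a' adds the residual connection, 'n' applies layer normalization, and 'd' applies dropout. bert.py drives this encoder with preprocess_cmd="" and postprocess_cmd="dan" (post-layer-norm ordering), while the defaults here are "n"/"da". A minimal sketch of wiring encoder() to dummy Fluid 1.x data layers follows; the sizes and dropout rates are illustrative, and the 10000.0 scale is a stand-in for the config["self_att_scale"] value bert.py reads from its config.

import paddle.fluid as fluid

seq_len, d_model, n_head, n_layer = 16, 768, 12, 2

emb = fluid.layers.data(name='emb', shape=[seq_len, d_model], dtype='float32')
mask = fluid.layers.data(name='mask', shape=[seq_len, 1], dtype='float32')

# Same attention-bias construction as bert.py: outer product of the padding
# mask, rescaled into an additive bias (0 for real tokens, -10000 for padding),
# then replicated once per attention head.
attn_mask = fluid.layers.matmul(x=mask, y=mask, transpose_y=True)
attn_mask = fluid.layers.scale(
    x=attn_mask, scale=10000.0, bias=-1.0, bias_after_scale=False)
attn_bias = fluid.layers.stack(x=[attn_mask] * n_head, axis=1)
attn_bias.stop_gradient = True

enc_out = encoder(
    enc_input=emb,
    attn_bias=attn_bias,
    n_layer=n_layer,
    n_head=n_head,
    d_key=d_model // n_head,
    d_value=d_model // n_head,
    d_model=d_model,
    d_inner_hid=d_model * 4,
    prepostprocess_dropout=0.1,
    attention_dropout=0.1,
    relu_dropout=0.0,
    hidden_act='relu',
    preprocess_cmd='',      # bert.py passes "" here ...
    postprocess_cmd='dan',  # ... and "dan": dropout, residual add, layer norm
    param_initializer=None,
    name='encoder')
# enc_out: [batch, seq_len, d_model]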