Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PALM
提交
2555c0e2
P
PALM
项目概览
PaddlePaddle
/
PALM
通知
5
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PALM
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2555c0e2
编写于
6月 08, 2020
作者:
X
xixiaoyao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ernie argument
上级
bbbb7357
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
26 addition
and
15 deletion
+26
-15
paddlepalm/backbone/ernie.py
paddlepalm/backbone/ernie.py
+16
-10
paddlepalm/head/base_head.py
paddlepalm/head/base_head.py
+3
-3
paddlepalm/head/match.py
paddlepalm/head/match.py
+4
-2
paddlepalm/trainer.py
paddlepalm/trainer.py
+3
-0
未找到文件。
paddlepalm/backbone/ernie.py
浏览文件 @
2555c0e2
...
...
@@ -31,7 +31,7 @@ class ERNIE(Backbone):
def
__init__
(
self
,
hidden_size
,
num_hidden_layers
,
num_attention_heads
,
vocab_size
,
\
max_position_embeddings
,
sent_type_vocab_size
,
task_type_vocab_size
,
\
hidden_act
,
hidden_dropout_prob
,
attention_probs_dropout_prob
,
initializer_range
,
is_pairwise
=
False
,
phase
=
'train'
):
hidden_act
,
hidden_dropout_prob
,
attention_probs_dropout_prob
,
initializer_range
,
is_pairwise
=
False
,
use_task_emb
=
True
,
phase
=
'train'
):
# self._is_training = phase == 'train' # backbone一般不用关心运行阶段,因为outputs在任何阶段基本不会变
...
...
@@ -54,6 +54,7 @@ class ERNIE(Backbone):
self
.
_task_emb_name
=
"task_embedding"
self
.
_emb_dtype
=
"float32"
self
.
_is_pairwise
=
is_pairwise
self
.
_use_task_emb
=
use_task_emb
self
.
_phase
=
phase
self
.
_param_initializer
=
fluid
.
initializer
.
TruncatedNormal
(
scale
=
initializer_range
)
...
...
@@ -85,6 +86,10 @@ class ERNIE(Backbone):
task_type_vocab_size
=
config
[
'task_type_vocab_size'
]
else
:
task_type_vocab_size
=
config
[
'type_vocab_size'
]
if
'use_task_emb'
in
config
:
use_task_emb
=
config
[
'use_task_emb'
]
else
:
use_task_emb
=
True
hidden_act
=
config
[
'hidden_act'
]
hidden_dropout_prob
=
config
[
'hidden_dropout_prob'
]
attention_probs_dropout_prob
=
config
[
'attention_probs_dropout_prob'
]
...
...
@@ -96,7 +101,7 @@ class ERNIE(Backbone):
return
cls
(
hidden_size
,
num_hidden_layers
,
num_attention_heads
,
vocab_size
,
\
max_position_embeddings
,
sent_type_vocab_size
,
task_type_vocab_size
,
\
hidden_act
,
hidden_dropout_prob
,
attention_probs_dropout_prob
,
initializer_range
,
is_pairwise
,
phase
=
phase
)
hidden_act
,
hidden_dropout_prob
,
attention_probs_dropout_prob
,
initializer_range
,
is_pairwise
,
use_task_emb
=
use_task_emb
,
phase
=
phase
)
@
property
def
inputs_attr
(
self
):
...
...
@@ -180,6 +185,7 @@ class ERNIE(Backbone):
emb_out
=
emb_out
+
position_emb_out
emb_out
=
emb_out
+
sent_emb_out
if
self
.
_use_task_emb
:
task_emb_out
=
fluid
.
embedding
(
task_ids
,
size
=
[
self
.
_task_types
,
self
.
_emb_size
],
...
...
paddlepalm/head/base_head.py
浏览文件 @
2555c0e2
...
...
@@ -122,11 +122,11 @@ class Head(object):
output_dir: 积累结果的保存路径。
"""
if
output_dir
is
not
None
:
for
i
in
self
.
_results_buffer
:
print
(
i
)
else
:
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
output_dir
)
with
open
(
os
.
path
.
join
(
output_dir
,
self
.
_phase
),
'w'
)
as
writer
:
for
i
in
self
.
_results_buffer
:
writer
.
write
(
json
.
dumps
(
i
)
+
'
\n
'
)
else
:
return
self
.
_results_buffer
paddlepalm/head/match.py
浏览文件 @
2555c0e2
...
...
@@ -159,8 +159,6 @@ class Match(Head):
else
:
return
{
'probs'
:
pos_score
}
def
batch_postprocess
(
self
,
rt_outputs
):
if
not
self
.
_is_training
:
probs
=
[]
...
...
@@ -171,6 +169,10 @@ class Match(Head):
logits
=
rt_outputs
[
'logits'
]
self
.
_preds_logits
.
extend
(
logits
.
tolist
())
def
reset
(
self
):
self
.
_preds_logits
=
[]
self
.
_preds
=
[]
def
epoch_postprocess
(
self
,
post_inputs
,
output_dir
=
None
):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if
not
self
.
_is_training
:
...
...
paddlepalm/trainer.py
浏览文件 @
2555c0e2
...
...
@@ -587,6 +587,9 @@ class Trainer(object):
results
=
self
.
_pred_head
.
epoch_postprocess
({
'reader'
:
reader_outputs
},
output_dir
=
output_dir
)
return
results
def
reset_buffer
(
self
):
self
.
_pred_head
.
reset
()
def
_check_phase
(
self
,
phase
):
assert
phase
in
[
'train'
,
'predict'
],
"Supported phase: train, predict,"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录