Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PALM
提交
0f062464
P
PALM
项目概览
PaddlePaddle
/
PALM
通知
4
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PALM
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0f062464
编写于
4月 24, 2020
作者:
X
Xiaoyao Xi
提交者:
GitHub
4月 24, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #79 from wangxiao1021/api
remove dropout in predict, fix
#77
, update postprocess
上级
82874d8f
bb9803a0
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
61 addition
and
53 deletion
+61
-53
paddlepalm/backbone/bert.py
paddlepalm/backbone/bert.py
+2
-2
paddlepalm/backbone/ernie.py
paddlepalm/backbone/ernie.py
+2
-2
paddlepalm/head/cls.py
paddlepalm/head/cls.py
+12
-9
paddlepalm/head/match.py
paddlepalm/head/match.py
+15
-12
paddlepalm/head/mlm.py
paddlepalm/head/mlm.py
+9
-7
paddlepalm/head/mrc.py
paddlepalm/head/mrc.py
+15
-15
paddlepalm/head/ner.py
paddlepalm/head/ner.py
+6
-6
未找到文件。
paddlepalm/backbone/bert.py
浏览文件 @
0f062464
...
@@ -42,8 +42,8 @@ class BERT(Backbone):
...
@@ -42,8 +42,8 @@ class BERT(Backbone):
self
.
_hidden_act
=
hidden_act
self
.
_hidden_act
=
hidden_act
self
.
_prepostprocess_dropout
=
hidden_dropout_prob
self
.
_prepostprocess_dropout
=
0.
if
phase
==
'predict'
else
hidden_dropout_prob
self
.
_attention_dropout
=
attention_probs_dropout_prob
self
.
_attention_dropout
=
0.
if
phase
==
'predict'
else
attention_probs_dropout_prob
self
.
_word_emb_name
=
"word_embedding"
self
.
_word_emb_name
=
"word_embedding"
self
.
_pos_emb_name
=
"pos_embedding"
self
.
_pos_emb_name
=
"pos_embedding"
...
...
paddlepalm/backbone/ernie.py
浏览文件 @
0f062464
...
@@ -45,8 +45,8 @@ class ERNIE(Backbone):
...
@@ -45,8 +45,8 @@ class ERNIE(Backbone):
self
.
_task_types
=
task_type_vocab_size
self
.
_task_types
=
task_type_vocab_size
self
.
_hidden_act
=
hidden_act
self
.
_hidden_act
=
hidden_act
self
.
_prepostprocess_dropout
=
hidden_dropout_prob
self
.
_prepostprocess_dropout
=
0.
if
phase
==
'predict'
else
hidden_dropout_prob
self
.
_attention_dropout
=
attention_probs_dropout_prob
self
.
_attention_dropout
=
0.
if
phase
==
'predict'
else
attention_probs_dropout_prob
self
.
_word_emb_name
=
"word_embedding"
self
.
_word_emb_name
=
"word_embedding"
self
.
_pos_emb_name
=
"pos_embedding"
self
.
_pos_emb_name
=
"pos_embedding"
...
...
paddlepalm/head/cls.py
浏览文件 @
0f062464
...
@@ -94,14 +94,17 @@ class Classify(Head):
...
@@ -94,14 +94,17 @@ class Classify(Head):
def
epoch_postprocess
(
self
,
post_inputs
,
output_dir
=
None
):
def
epoch_postprocess
(
self
,
post_inputs
,
output_dir
=
None
):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if
not
self
.
_is_training
:
if
not
self
.
_is_training
:
if
output_dir
is
None
:
results
=
[]
raise
ValueError
(
'argument output_dir not found in config. Please add it into config dict/file.'
)
with
open
(
os
.
path
.
join
(
output_dir
,
'predictions.json'
),
'w'
)
as
writer
:
for
i
in
range
(
len
(
self
.
_preds
)):
for
i
in
range
(
len
(
self
.
_preds
)):
label
=
int
(
np
.
argmax
(
np
.
array
(
self
.
_preds
[
i
])))
label
=
int
(
np
.
argmax
(
np
.
array
(
self
.
_preds
[
i
])))
result
=
{
'index'
:
i
,
'label'
:
label
,
'logits'
:
self
.
_preds
[
i
],
'probs'
:
self
.
_probs
[
i
]}
result
=
{
'index'
:
i
,
'label'
:
label
,
'logits'
:
self
.
_preds
[
i
],
'probs'
:
self
.
_probs
[
i
]}
results
.
append
(
result
)
if
output_dir
is
not
None
:
with
open
(
os
.
path
.
join
(
output_dir
,
'predictions.json'
),
'w'
)
as
writer
:
for
result
in
results
:
result
=
json
.
dumps
(
result
)
result
=
json
.
dumps
(
result
)
writer
.
write
(
result
+
'
\n
'
)
writer
.
write
(
result
+
'
\n
'
)
print
(
'Predictions saved at '
+
os
.
path
.
join
(
output_dir
,
'predictions.json'
))
print
(
'Predictions saved at '
+
os
.
path
.
join
(
output_dir
,
'predictions.json'
))
return
results
paddlepalm/head/match.py
浏览文件 @
0f062464
...
@@ -174,15 +174,18 @@ class Match(Head):
...
@@ -174,15 +174,18 @@ class Match(Head):
def
epoch_postprocess
(
self
,
post_inputs
,
output_dir
=
None
):
def
epoch_postprocess
(
self
,
post_inputs
,
output_dir
=
None
):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if
not
self
.
_is_training
:
if
not
self
.
_is_training
:
if
output_dir
is
None
:
results
=
[]
raise
ValueError
(
'argument output_dir not found in config. Please add it into config dict/file.'
)
with
open
(
os
.
path
.
join
(
output_dir
,
'predictions.json'
),
'w'
)
as
writer
:
for
i
in
range
(
len
(
self
.
_preds
)):
for
i
in
range
(
len
(
self
.
_preds
)):
if
self
.
_learning_strategy
==
'pointwise'
:
if
self
.
_learning_strategy
==
'pointwise'
:
label
=
int
(
np
.
argmax
(
np
.
array
(
self
.
_preds
[
i
])))
label
=
int
(
np
.
argmax
(
np
.
array
(
self
.
_preds
[
i
])))
result
=
{
'index'
:
i
,
'label'
:
label
,
'logits'
:
self
.
_preds_logits
[
i
],
'probs'
:
self
.
_preds
[
i
]}
result
=
{
'index'
:
i
,
'label'
:
label
,
'logits'
:
self
.
_preds_logits
[
i
],
'probs'
:
self
.
_preds
[
i
]}
elif
self
.
_learning_strategy
==
'pairwise'
:
elif
self
.
_learning_strategy
==
'pairwise'
:
result
=
{
'index'
:
i
,
'probs'
:
self
.
_preds
[
i
][
0
]}
result
=
{
'index'
:
i
,
'probs'
:
self
.
_preds
[
i
][
0
]}
results
.
append
(
result
)
if
output_dir
is
not
None
:
with
open
(
os
.
path
.
join
(
output_dir
,
'predictions.json'
),
'w'
)
as
writer
:
for
result
in
results
:
result
=
json
.
dumps
(
result
,
ensure_ascii
=
False
)
result
=
json
.
dumps
(
result
,
ensure_ascii
=
False
)
writer
.
write
(
result
+
'
\n
'
)
writer
.
write
(
result
+
'
\n
'
)
print
(
'Predictions saved at '
+
os
.
path
.
join
(
output_dir
,
'predictions.json'
))
print
(
'Predictions saved at '
+
os
.
path
.
join
(
output_dir
,
'predictions.json'
))
return
results
paddlepalm/head/mlm.py
浏览文件 @
0f062464
...
@@ -128,13 +128,15 @@ class MaskLM(Head):
...
@@ -128,13 +128,15 @@ class MaskLM(Head):
def
epoch_postprocess
(
self
,
post_inputs
,
output_dir
=
None
):
def
epoch_postprocess
(
self
,
post_inputs
,
output_dir
=
None
):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if
not
self
.
_is_training
:
if
not
self
.
_is_training
:
if
output_dir
is
None
:
results
=
[]
for
p
in
self
.
_preds
:
for
i
in
range
(
len
(
self
.
_preds
)):
print
(
p
)
result
=
{
'index'
:
i
,
'word_id'
:
self
.
_preds
[
i
]}
else
:
results
.
append
(
result
)
if
output_dir
is
not
None
:
with
open
(
os
.
path
.
join
(
output_dir
,
'predictions.json'
),
'w'
)
as
writer
:
with
open
(
os
.
path
.
join
(
output_dir
,
'predictions.json'
),
'w'
)
as
writer
:
for
p
in
self
.
_preds
:
for
result
in
results
:
writer
.
write
(
str
(
p
)
+
'
\n
'
)
result
=
json
.
dumps
(
result
)
writer
.
write
(
result
+
'
\n
'
)
print
(
'Predictions saved at '
+
os
.
path
.
join
(
output_dir
,
'predictions.json'
))
print
(
'Predictions saved at '
+
os
.
path
.
join
(
output_dir
,
'predictions.json'
))
return
results
paddlepalm/head/mrc.py
浏览文件 @
0f062464
...
@@ -154,8 +154,7 @@ class MRC(Head):
...
@@ -154,8 +154,7 @@ class MRC(Head):
"""(optional interface) this func will be called after evaluation/predicting process and each epoch during training process."""
"""(optional interface) this func will be called after evaluation/predicting process and each epoch during training process."""
if
not
self
.
_is_training
:
if
not
self
.
_is_training
:
if
output_dir
is
None
:
if
output_dir
is
not
None
:
raise
ValueError
(
'argument output_dir not found in config. Please add it into config dict/file.'
)
examples
=
post_inputs
[
'reader'
][
'examples'
]
examples
=
post_inputs
[
'reader'
][
'examples'
]
features
=
post_inputs
[
'reader'
][
'features'
]
features
=
post_inputs
[
'reader'
][
'features'
]
if
not
os
.
path
.
exists
(
output_dir
):
if
not
os
.
path
.
exists
(
output_dir
):
...
@@ -169,6 +168,7 @@ class MRC(Head):
...
@@ -169,6 +168,7 @@ class MRC(Head):
output_nbest_file
,
output_null_log_odds_file
,
output_nbest_file
,
output_null_log_odds_file
,
self
.
_with_negative
,
self
.
_with_negative
,
self
.
_null_score_diff_threshold
,
self
.
_verbose
)
self
.
_null_score_diff_threshold
,
self
.
_verbose
)
return
self
.
_pred_results
def
_write_predictions
(
all_examples
,
all_features
,
all_results
,
n_best_size
,
def
_write_predictions
(
all_examples
,
all_features
,
all_results
,
n_best_size
,
...
...
paddlepalm/head/ner.py
浏览文件 @
0f062464
...
@@ -118,9 +118,9 @@ class SequenceLabel(Head):
...
@@ -118,9 +118,9 @@ class SequenceLabel(Head):
def
epoch_postprocess
(
self
,
post_inputs
,
output_dir
=
None
):
def
epoch_postprocess
(
self
,
post_inputs
,
output_dir
=
None
):
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
# there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs
if
not
self
.
_is_training
:
if
not
self
.
_is_training
:
if
output_dir
is
None
:
if
output_dir
is
not
None
:
raise
ValueError
(
'argument output_dir not found in config. Please add it into config dict/file.'
)
with
open
(
os
.
path
.
join
(
output_dir
,
'predictions.json'
),
'w'
)
as
writer
:
with
open
(
os
.
path
.
join
(
output_dir
,
'predictions.json'
),
'w'
)
as
writer
:
for
p
in
self
.
_preds
:
for
p
in
self
.
_preds
:
writer
.
write
(
str
(
p
)
+
'
\n
'
)
writer
.
write
(
str
(
p
)
+
'
\n
'
)
print
(
'Predictions saved at '
+
os
.
path
.
join
(
output_dir
,
'predictions.json'
))
print
(
'Predictions saved at '
+
os
.
path
.
join
(
output_dir
,
'predictions.json'
))
return
self
.
_preds
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录