Commit d71b37d0
"fix bugs"
Authored on Feb 04, 2020 by wangxiao1021
Parent: 1bb38efb

Showing 16 changed files with 372 additions and 358 deletions.
examples/classification/README.md          +1    -1
examples/matching/README.md                +2    -2
examples/mrc/README.md                     +2    -1
examples/multi-task/download.py            +7    -12
examples/multi-task/evaluate.py            +74   -165
examples/multi-task/predict.py             +59   -0
examples/multi-task/process.py             +122  -32
examples/multi-task/run.py                 +37   -66
examples/predict/README.md                 +1    -1
examples/tagging/README.md                 +1    -1
examples/tagging/run.py                    +21   -21
paddlepalm/head/mlm.py                     +8    -11
paddlepalm/multihead_trainer.py            +13   -34
paddlepalm/reader/mlm.py                   +0    -1
paddlepalm/reader/utils/reader4ernie.py    +1    -1
paddlepalm/trainer.py                      +23   -9
examples/classification/README.md
@@ -32,7 +32,7 @@ label text_a
 ### Step 2: Train & Predict
-The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
+The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
 ```shell
 python run.py
examples/matching/README.md
@@ -21,7 +21,7 @@ python download.py
 After the dataset is downloaded, you should convert the data format for training:
 ```shell
-python process.py quora_duplicate_questions.tsv train.tsv test.tsv
+python process.py data/quora_duplicate_questions.tsv data/train.tsv data/test.tsv
 ```
 If everything goes well, there will be a folder named `data/` created with all the converted datas in it.
@@ -40,7 +40,7 @@ What are the differences between the Dell Inspiron 3000, 5000, and 7000 series l
 ### Step 2: Train & Predict
-The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
+The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
 ```shell
 python run.py
examples/mrc/README.md
@@ -39,12 +39,13 @@ Here is some example datas:
 }
 ]
 }
+}
 ```
 ### Step 2: Train & Predict
-The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
+The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
 ```shell
 python run.py
examples/multi-task/download.py
@@ -28,8 +28,8 @@ def download(src, url):
 abs_path = os.path.abspath(__file__)
-download_url = "https://ernie.bj.bcebos.com/task_data_zh.tgz"
-downlaod_path = os.path.join(os.path.dirname(abs_path), "task_data_zh.tgz")
+download_url = "https://baidu-nlp.bj.bcebos.com/dmtk_data_1.0.0.tar.gz"
+downlaod_path = os.path.join(os.path.dirname(abs_path), "dmtk_data_1.0.0.tar.gz")
 target_dir = os.path.dirname(abs_path)
 download(downlaod_path, download_url)
@@ -37,14 +37,9 @@ tar = tarfile.open(downlaod_path)
 tar.extractall(target_dir)
 os.remove(downlaod_path)
-abs_path = os.path.abspath(__file__)
-dst_dir = os.path.join(os.path.dirname(abs_path), "data/mrc")
-if not os.path.exists(dst_dir) or not os.path.isdir(dst_dir):
-    os.makedirs(dst_dir)
-for file in os.listdir(os.path.join(target_dir, 'task_data', 'cmrc2018')):
-    shutil.move(os.path.join(target_dir, 'task_data', 'cmrc2018', file), dst_dir)
-shutil.rmtree(os.path.join(target_dir, 'task_data'))
+shutil.rmtree(os.path.join(target_dir, 'data/dstc2/'))
+shutil.rmtree(os.path.join(target_dir, 'data/mrda/'))
+shutil.rmtree(os.path.join(target_dir, 'data/multi-woz/'))
+shutil.rmtree(os.path.join(target_dir, 'data/swda/'))
+shutil.rmtree(os.path.join(target_dir, 'data/udc/'))
examples/multi-task/evaluate.py
 # -*- coding: utf-8 -*-
 # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-'''
-Evaluation script for CMRC 2018
-version: v5
-Note:
-v5 formatted output, add usage description
-v4 fixed segmentation issues
-'''
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-from __future__ import unicode_literals
-from __future__ import absolute_import
-from collections import Counter, OrderedDict
-import string
-import re
-import argparse
 import json
-import sys
-import nltk
-import pdb
-
-# split Chinese with English
-def mixed_segmentation(in_str, rm_punc=False):
-    in_str = in_str.lower().strip()
-    segs_out = []
-    temp_str = ""
-    sp_char = ['-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':', '?', '!',
-               '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(', ')', '-', '~', '『', '』']
-    for char in in_str:
-        if rm_punc and char in sp_char:
-            continue
-        if re.search(r'[\u4e00-\u9fa5]', char) or char in sp_char:
-            if temp_str != "":
-                ss = nltk.word_tokenize(temp_str)
-                segs_out.extend(ss)
-                temp_str = ""
-            segs_out.append(char)
-        else:
-            temp_str += char
-    # handling last part
-    if temp_str != "":
-        ss = nltk.word_tokenize(temp_str)
-        segs_out.extend(ss)
-    return segs_out
+
+def load_label_map(map_dir="./data/atis/atis_slot/label_map.json"):
+    """
+    :param map_dir: dict indicating chunk type
+    :return:
+    """
+    return json.load(open(map_dir, "r"))

-# remove punctuation
-def remove_punctuation(in_str):
-    in_str = in_str.lower().strip()
-    sp_char = ['-', ':', '_', '*', '^', '/', '\\', '~', '`', '+', '=', ',', '。', ':', '?', '!',
-               '“', '”', ';', '’', '《', '》', '……', '·', '、', '「', '」', '(', ')', '-', '~', '『', '』']
-    out_segs = []
-    for char in in_str:
-        if char in sp_char:
-            continue
-        else:
-            out_segs.append(char)
-    return ''.join(out_segs)
+
+def cal_chunk(total_res, total_label):
+    assert len(total_label) == len(total_res), 'prediction result doesn\'t match to labels'
+    num_labels = 0
+    num_corr = 0
+    num_infers = 0
+    for res, label in zip(total_res, total_label):
+        assert len(res) == len(label), "prediction result doesn't match to labels"
+        num_labels += sum([0 if i == 6 else 1 for i in label])
+        num_corr += sum([1 if label[i] == res[i] and label[i] != 6 else 0 for i in range(len(label))])
+        num_infers += sum([0 if i == 6 else 1 for i in res])
+
+    precision = num_corr * 1.0 / num_infers if num_infers > 0 else 0.0
+    recall = num_corr * 1.0 / num_labels if num_labels > 0 else 0.0
+    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
+
+    return precision, recall, f1

-# find longest common string
-def find_lcs(s1, s2):
-    m = [[0 for i in range(len(s2) + 1)] for j in range(len(s1) + 1)]
-    mmax = 0
-    p = 0
-    for i in range(len(s1)):
-        for j in range(len(s2)):
-            if s1[i] == s2[j]:
-                m[i + 1][j + 1] = m[i][j] + 1
-                if m[i + 1][j + 1] > mmax:
-                    mmax = m[i + 1][j + 1]
-                    p = i + 1
-    return s1[p - mmax:p], mmax
-
-def evaluate(ground_truth_file, prediction_file):
-    f1 = 0
-    em = 0
-    total_count = 0
-    skip_count = 0
-    for instances in ground_truth_file["data"]:
-        for instance in instances["paragraphs"]:
-            context_text = instance['context'].strip()
-            for qas in instance['qas']:
-                total_count += 1
-                query_id = qas['id'].strip()
-                query_text = qas['question'].strip()
-                answers = [ans["text"] for ans in qas["answers"]]
-                if query_id not in prediction_file:
-                    print('Unanswered question: {}\n'.format(query_id))
-                    skip_count += 1
-                    continue
-                prediction = prediction_file[query_id]
-                f1 += calc_f1_score(answers, prediction)
-                em += calc_em_score(answers, prediction)
-    f1_score = 100.0 * f1 / total_count
-    em_score = 100.0 * em / total_count
-    return f1_score, em_score, total_count, skip_count
-
-def calc_f1_score(answers, prediction):
-    f1_scores = []
-    for ans in answers:
-        ans_segs = mixed_segmentation(ans, rm_punc=True)
-        prediction_segs = mixed_segmentation(prediction, rm_punc=True)
-        lcs, lcs_len = find_lcs(ans_segs, prediction_segs)
-        if lcs_len == 0:
-            f1_scores.append(0)
-            continue
-        precision = 1.0 * lcs_len / len(prediction_segs)
-        recall = 1.0 * lcs_len / len(ans_segs)
-        f1 = (2 * precision * recall) / (precision + recall)
-        f1_scores.append(f1)
-    return max(f1_scores)
-
-def calc_em_score(answers, prediction):
-    em = 0
-    for ans in answers:
-        ans_ = remove_punctuation(ans)
-        prediction_ = remove_punctuation(prediction)
-        if ans_ == prediction_:
-            em = 1
-            break
-    return em
-
-def eval_file(dataset_file, prediction_file):
-    ground_truth_file = json.load(open(dataset_file, 'r'))
-    prediction_file = json.load(open(prediction_file, 'r'))
-    F1, EM, TOTAL, SKIP = evaluate(ground_truth_file, prediction_file)
-    AVG = (EM + F1) * 0.5
-    return EM, F1, AVG, TOTAL
-
-if __name__ == '__main__':
-    EM, F1, AVG, TOTAL = eval_file("task_data/cmrc2018/dev.json", "predictions.json")
-    print(EM)
-    print(F1)
-    print(TOTAL)
\ No newline at end of file
+
+def res_evaluate(res_dir="./outputs/predict/predictions.json", data_dir="./data/atis/atis_slot/test.tsv"):
+    label_map = load_label_map()
+
+    total_label = []
+    with open(data_dir, "r") as file:
+        first_flag = True
+        for line in file:
+            if first_flag:
+                first_flag = False
+                continue
+            line = line.strip("\n")
+            if len(line) == 0:
+                continue
+            line = line.split("\t")
+            if len(line) < 2:
+                continue
+            labels = line[1][:-1].split("\x02")
+            total_label.append(labels)
+    total_label = [[label_map[j] for j in i] for i in total_label]
+
+    total_res = []
+    with open(res_dir, "r") as file:
+        cnt = 0
+        for line in file:
+            line = line.strip("\n")
+            if len(line) == 0:
+                continue
+            try:
+                res_arr = json.loads(line)
+                if len(total_label[cnt]) < len(res_arr):
+                    total_res.append(res_arr[1: 1 + len(total_label[cnt])])
+                elif len(total_label[cnt]) == len(res_arr):
+                    total_res.append(res_arr)
+                else:
+                    total_res.append(res_arr)
+                    total_label[cnt] = total_label[cnt][:len(res_arr)]
+            except:
+                print("json format error: {}".format(cnt))
+                print(line)
+            cnt += 1
+
+    precision, recall, f1 = cal_chunk(total_res, total_label)
+    print("precision: {}, recall: {}, f1: {}".format(precision, recall, f1))
+
+res_evaluate()
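For reference, the new `cal_chunk` metric can be exercised on its own with toy inputs. This is a minimal sketch, assuming (as the code above hard-codes) that tag id 6 plays the role of the "outside"/non-chunk label excluded from the counts; the lists below are made-up examples, not project data.

```python
# Minimal sketch of how cal_chunk scores predicted vs. gold tag-id sequences.
# Assumption: id 6 is the "O"/non-chunk tag, as hard-coded in cal_chunk above.
gold = [[1, 6, 2, 6], [3, 3, 6]]   # hypothetical gold tag ids per token
pred = [[1, 6, 5, 6], [3, 6, 6]]   # hypothetical predictions, same lengths

precision, recall, f1 = cal_chunk(pred, gold)
# num_infers = 3 non-6 predictions, num_labels = 4 non-6 gold tags,
# num_corr = 2 positions where pred == gold and the tag is not 6
print(precision, recall, f1)  # 0.666..., 0.5, 0.571...
```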
examples/multi-task/predict.py (new file, mode 100644)
# coding=utf-8
import paddlepalm as palm
import json
from paddlepalm.distribute import gpu_dev_count

if __name__ == '__main__':

    # configs
    max_seqlen = 256
    batch_size = 16
    num_epochs = 6
    print_steps = 5
    lr = 5e-5
    num_classes = 130
    random_seed = 1
    label_map = './data/atis/atis_slot/label_map.json'
    vocab_path = './pretrain/ernie-en-base/vocab.txt'
    predict_file = './data/atis/atis_slot/test.tsv'
    save_path = './outputs/'
    pred_output = './outputs/predict/'
    save_type = 'ckpt'
    pre_params = './pretrain/ernie-en-base/params'
    config = json.load(open('./pretrain/ernie-en-base/ernie_config.json'))
    input_dim = config['hidden_size']

    # ----------------------- for prediction -----------------------

    # step 1-1: create readers for prediction
    print('prepare to predict...')
    predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, phase='predict')
    # step 1-2: load the training data
    predict_seq_label_reader.load_data(predict_file, batch_size)

    # step 2: create a backbone of the model to extract text features
    pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')

    # step 3: register the backbone in reader
    predict_seq_label_reader.register_with(pred_ernie)

    # step 4: create the task output head
    seq_label_pred_head = palm.head.SequenceLabel(num_classes, input_dim, phase='predict')

    # step 5-1: create a task trainer
    trainer_seq_label = palm.Trainer("slot")
    # step 5-2: build forward graph with backbone and task head
    trainer_seq_label.build_predict_forward(pred_ernie, seq_label_pred_head)

    # step 6: load pretrained model
    pred_model_path = './outputs/1580822697.73-ckpt.step9282'
    pred_ckpt = trainer_seq_label.load_ckpt(pred_model_path)

    # step 7: fit prepared reader and data
    trainer_seq_label.fit_reader(predict_seq_label_reader, phase='predict')

    # step 8: predict
    print('predicting..')
    trainer_seq_label.predict(print_steps=print_steps, output_dir=pred_output)
\ No newline at end of file
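The checkpoint path in step 6 is hard-coded to one particular training run (`./outputs/1580822697.73-ckpt.step9282`). If you would rather not edit the script by hand after each run, a small helper like the following could pick the newest `ckpt.step*` directory under `./outputs/` instead; this is a hypothetical convenience snippet, not part of the committed example.

```python
import glob
import re

def latest_ckpt(output_dir='./outputs/'):
    """Hypothetical helper: return the checkpoint dir with the largest step number."""
    ckpts = glob.glob(output_dir + '*ckpt.step*')
    if not ckpts:
        raise IOError('no checkpoint found under ' + output_dir)
    # extract the trailing step number from names like '...ckpt.step9282'
    return max(ckpts, key=lambda p: int(re.search(r'step(\d+)$', p).group(1)))

# pred_model_path = latest_ckpt()   # instead of the hard-coded path above
```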
examples/multi-task/process.py
 # -*- coding: UTF-8 -*-
 import json
 import os
 import io

-abs_path = os.path.abspath(__file__)
-dst_dir = os.path.join(os.path.dirname(abs_path), "data/match/")
-if not os.path.exists(dst_dir) or not os.path.isdir(dst_dir):
-    os.makedirs(dst_dir)
-os.mknod("./data/match/train.tsv")
-
-with io.open("./data/mrc/train.json", "r", encoding='utf-8') as f:
-    data = json.load(f)["data"]
-
-i = 0
-with open("./data/match/train.tsv", "w") as f2:
-    f2.write("text_a\ttext_b\tlabel\n")
-    for dd in data:
-        for d in dd["paragraphs"]:
-            context = d["context"]
-            for qa in d["qas"]:
-                text_a = qa["question"]
-                answer = qa["answers"][0]
-                text_b = answer["text"]
-                start_pos = answer["answer_start"]
-                text_b_neg = context[0:start_pos]
-                if len(text_b_neg) > 512:
-                    text_b_neg = text_b_neg[-512:-1]
-                l1 = text_a + "\t" + text_b + "\t1\n"
-                l2 = text_a + "\t" + text_b_neg + "\t0\n"
-                if i < 14246:
-                    f2.write(l1.encode("utf-8"))
-                    f2.write(l2.encode("utf-8"))
-                i += 2
+import json
+
+label_new = "data/atis/atis_slot/label_map.json"
+label_old = "data/atis/atis_slot/map_tag_slot_id.txt"
+train_old = "data/atis/atis_slot/train.txt"
+train_new = "data/atis/atis_slot/train.tsv"
+dev_old = "data/atis/atis_slot/dev.txt"
+dev_new = "data/atis/atis_slot/dev.tsv"
+test_old = "data/atis/atis_slot/test.txt"
+test_new = "data/atis/atis_slot/test.tsv"
+
+intent_test = "data/atis/atis_intent/test.tsv"
+os.rename("data/atis/atis_intent/test.txt", intent_test)
+intent_train = "data/atis/atis_intent/train.tsv"
+os.rename("data/atis/atis_intent/train.txt", intent_train)
+intent_dev = "data/atis/atis_intent/dev.tsv"
+os.rename("data/atis/atis_intent/dev.txt", intent_dev)
+
+with open(intent_dev, 'r+') as f:
+    content = f.read()
+    f.seek(0, 0)
+    f.write("label\ttext_a\n" + content)
+    f.close()
+
+with open(intent_test, 'r+') as f:
+    content = f.read()
+    f.seek(0, 0)
+    f.write("label\ttext_a\n" + content)
+    f.close()
+
+with open(intent_train, 'r+') as f:
+    content = f.read()
+    f.seek(0, 0)
+    f.write("label\ttext_a\n" + content)
+    f.close()
+
+os.mknod(label_new)
+os.mknod(train_new)
+os.mknod(dev_new)
+os.mknod(test_new)
+
+tag = []
+id = []
+map = {}
+with open(label_old, "r") as f:
+    with open(label_new, "w") as f2:
+        for line in f.readlines():
+            line = line.split('\t')
+            tag.append(line[0])
+            id.append(int(line[1][:-1]))
+            map[line[1][:-1]] = line[0]
+        re = {tag[i]: id[i] for i in range(len(tag))}
+        re = json.dumps(re)
+        f2.write(re)
+        f2.close()
+    f.close()
+
+with open(train_old, "r") as f:
+    with open(train_new, "w") as f2:
+        f2.write("text_a\tlabel\n")
+        for line in f.readlines():
+            line = line.split('\t')
+            text = line[0].split(' ')
+            label = line[1].split(' ')
+            for t in text:
+                f2.write(t)
+                f2.write('\2')
+            f2.write('\t')
+            for t in label:
+                if t.endswith('\n'):
+                    t = t[:-1]
+                f2.write(map[t])
+                f2.write('\2')
+            f2.write('\n')
+        f2.close()
+    f.close()
+
+with open(test_old, "r") as f:
+    with open(test_new, "w") as f2:
+        f2.write("text_a\tlabel\n")
+        for line in f.readlines():
+            line = line.split('\t')
+            text = line[0].split(' ')
+            label = line[1].split(' ')
+            for t in text:
+                f2.write(t)
+                f2.write('\2')
+            f2.write('\t')
+            for t in label:
+                if t.endswith('\n'):
+                    t = t[:-1]
+                f2.write(map[t])
+                f2.write('\2')
+            f2.write('\n')
+        f2.close()
+    f.close()
+
+with open(dev_old, "r") as f:
+    with open(dev_new, "w") as f2:
+        f2.write("text_a\tlabel\n")
+        for line in f.readlines():
+            line = line.split('\t')
+            text = line[0].split(' ')
+            label = line[1].split(' ')
+            for t in text:
+                f2.write(t)
+                f2.write('\2')
+            f2.write('\t')
+            for t in label:
+                if t.endswith('\n'):
+                    t = t[:-1]
+                f2.write(map[t])
+                f2.write('\2')
+            f2.write('\n')
+        f2.close()
+    f.close()
+
+os.remove(label_old)
+os.remove(train_old)
+os.remove(test_old)
+os.remove(dev_old)
\ No newline at end of file
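As a quick sanity check on the conversion above: each record written to the ATIS slot `*.tsv` files is a two-column line whose tokens and label names are joined by the `\x02` (Ctrl-B) separator, the same separator that `evaluate.py` above splits on. A minimal sketch of reading one converted example back, assuming `process.py` has already produced the file:

```python
# Minimal sketch: read back one converted ATIS slot example.
# Assumes process.py above has produced data/atis/atis_slot/train.tsv.
with open("data/atis/atis_slot/train.tsv") as f:
    header = f.readline()                              # "text_a\tlabel\n"
    text_a, label = f.readline().rstrip("\n").split("\t")
    tokens = [t for t in text_a.split("\x02") if t]    # tokens joined by Ctrl-B
    tags = [t for t in label.split("\x02") if t]       # label names, one per token
    assert len(tokens) == len(tags)
```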
examples/multi-task/run.py
@@ -8,61 +8,61 @@ if __name__ == '__main__':

     # configs
     max_seqlen = 128
-    batch_size = 8
-    num_epochs = 8
-    lr = 3e-5
-    doc_stride = 128
-    max_query_len = 64
-    max_ans_len = 128
+    batch_size = 16
+    num_epochs = 20
+    print_steps = 5
+    lr = 2e-5
+    num_classes = 130
     weight_decay = 0.01
-    print_steps = 1
-    num_classes = 2
+    random_seed = 1
+    num_classes_intent = 26
     dropout_prob = 0.1
-    vocab_path = './pretrain/ernie-zh-base/vocab.txt'
-    do_lower_case = True
-    random_seed = 0
-    train_file = './data/mrc/train.json'
-    train_file_mlm = './data/mlm/train.tsv'
-    train_file_match = './data/match/train.tsv'
-    predict_file = './data/mrc/dev.json'
+    label_map = './data/atis/atis_slot/label_map.json'
+    vocab_path = './pretrain/ernie-en-base/vocab.txt'
+    train_slot = './data/atis/atis_slot/train.tsv'
+    train_intent = './data/atis/atis_intent/train.tsv'
+    predict_file = './data/atis/atis_slot/test.tsv'
     save_path = './outputs/'
     pred_output = './outputs/predict/'
     save_type = 'ckpt'
-    task_name = 'cmrc2018'
-    pre_params = './pretrain/ernie-zh-base/params'
-    config = json.load(open('./pretrain/ernie-zh-base/ernie_config.json'))
+    pre_params = './pretrain/ernie-en-base/params'
+    config = json.load(open('./pretrain/ernie-en-base/ernie_config.json'))
     input_dim = config['hidden_size']

     # ----------------------- for training -----------------------

     # step 1-1: create readers for training
-    mrc_reader = palm.reader.MRCReader(vocab_path, max_seqlen, max_query_len, doc_stride, do_lower_case=do_lower_case)
+    seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
     match_reader = palm.reader.MatchReader(vocab_path, max_seqlen, seed=random_seed)
     # step 1-2: load the training data
-    mrc_reader.load_data(train_file, file_format='json', num_epochs=None, batch_size=batch_size)
-    match_reader.load_data(train_file_match, file_format='tsv', num_epochs=None, batch_size=batch_size)
+    seq_label_reader.load_data(train_slot, file_format='tsv', num_epochs=None, batch_size=batch_size)
+    match_reader.load_data(train_intent, file_format='tsv', num_epochs=None, batch_size=batch_size)

     # step 2: create a backbone of the model to extract text features
     ernie = palm.backbone.ERNIE.from_config(config)

     # step 3: register the backbone in readers
-    mrc_reader.register_with(ernie)
+    seq_label_reader.register_with(ernie)
     match_reader.register_with(ernie)

     # step 4: create task output heads
-    mrc_head = palm.head.MRC(max_query_len, config['hidden_size'], do_lower_case=do_lower_case, max_ans_len=max_ans_len)
-    match_head = palm.head.Match(num_classes, input_dim, dropout_prob)
+    seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)
+    match_head = palm.head.Match(num_classes_intent, input_dim, dropout_prob)

     # step 5-1: create a task trainer
-    trainer_mrc = palm.Trainer(task_name, mix_ratio=1.0)
-    trainer_match = palm.Trainer("match", mix_ratio=0.5)
-    trainer = palm.MultiHeadTrainer([trainer_mrc, trainer_match])
-    # step 5-2: build forward graph with backbone and task head
-    loss_var = trainer.build_forward(ernie, [mrc_head, match_head])
+    trainer_seq_label = palm.Trainer("slot", mix_ratio=1.0)
+    trainer_match = palm.Trainer("intent", mix_ratio=0.5)
+    trainer = palm.MultiHeadTrainer([trainer_seq_label, trainer_match])
+    # # step 5-2: build forward graph with backbone and task head
+    loss_var1 = trainer_match.build_forward(ernie, match_head)
+    loss_var2 = trainer_seq_label.build_forward(ernie, seq_label_head)
+    loss_var = trainer.build_forward()

     # step 6-1*: use warmup
-    n_steps = mrc_reader.num_examples * 2 * num_epochs // batch_size
+    n_steps = seq_label_reader.num_examples * 1.5 * num_epochs // batch_size
     warmup_steps = int(0.1 * n_steps)
     sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps)
     # step 6-2: create a optimizer
@@ -71,42 +71,13 @@ if __name__ == '__main__':
     trainer.build_backward(optimizer=adam, weight_decay=weight_decay)

     # step 7: fit prepared reader and data
-    trainer.fit_readers_with_mixratio([mrc_reader, match_reader], task_name, num_epochs)
+    trainer.fit_readers_with_mixratio([seq_label_reader, match_reader], "slot", num_epochs)

     # step 8-1*: load pretrained parameters
     trainer.load_pretrain(pre_params)
     # step 8-2*: set saver to save model
-    save_steps = n_steps - batch_size
-    trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
+    # save_steps = int(n_steps-batch_size)
+    save_steps = 10
+    trainer_seq_label.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type, is_multi=True)
     # step 8-3: start training
     trainer.train(print_steps=print_steps)
\ No newline at end of file
-
-    # ----------------------- for prediction -----------------------
-
-    # step 1-1: create readers for prediction
-    predict_mrc_reader = palm.reader.MRCReader(vocab_path, max_seqlen, max_query_len, doc_stride, do_lower_case=do_lower_case, phase='predict')
-    # step 1-2: load the training data
-    predict_mrc_reader.load_data(predict_file, batch_size)
-    # step 2: create a backbone of the model to extract text features
-    pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')
-    # step 3: register the backbone in reader
-    predict_mrc_reader.register_with(pred_ernie)
-    # step 4: create the task output head
-    mrc_pred_head = palm.head.MRC(max_query_len, config['hidden_size'], do_lower_case=do_lower_case, max_ans_len=max_ans_len, phase='predict')
-    # step 5: build forward graph with backbone and task head
-    trainer_mrc.build_predict_forward(pred_ernie, mrc_pred_head)
-    # step 6: load pretrained model
-    pred_model_path = './outputs/ckpt.step' + str(save_steps)
-    pred_ckpt = trainer_mrc.load_ckpt(pred_model_path)
-    # step 7: fit prepared reader and data
-    trainer_mrc.fit_reader(predict_mrc_reader, phase='predict')
-    # step 8: predict
-    print('predicting..')
-    trainer_mrc.predict(print_steps=print_steps, output_dir="outputs/")
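One detail worth noting in the updated warmup computation: the 1.5 factor in `n_steps = seq_label_reader.num_examples * 1.5 * num_epochs // batch_size` matches the sum of the two mix ratios (1.0 for "slot" plus 0.5 for "intent"), so the schedule appears to be sized for the combined multi-task step count rather than for the slot task alone. A sketch of that arithmetic, with made-up example sizes:

```python
# Sketch of the step budget implied by the mix ratios above (illustrative numbers only).
num_examples = 4478          # hypothetical size of the ATIS slot training set
batch_size = 16
num_epochs = 20
mix_ratios = {"slot": 1.0, "intent": 0.5}

total_ratio = sum(mix_ratios.values())                        # 1.5
n_steps = int(num_examples * total_ratio * num_epochs // batch_size)
warmup_steps = int(0.1 * n_steps)
print(n_steps, warmup_steps)
```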
examples/predict/README.md
@@ -32,7 +32,7 @@ label text_a
 ### Step 2: Predict
-The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
+The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
 ```shell
 python run.py
examples/tagging/README.md
@@ -34,7 +34,7 @@ text_a label
 ### Step 2: Train & Predict
-The code used to perform classification task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
+The code used to perform this task is in `run.py`. If you have prepared the pre-training model and the data set required for the task, run:
 ```shell
 python run.py
examples/tagging/run.py
@@ -32,26 +32,26 @@ if __name__ == '__main__':

     # ----------------------- for training -----------------------

     # step 1-1: create readers for training
-    ner_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
+    seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, seed=random_seed)
     # step 1-2: load the training data
-    ner_reader.load_data(train_file, file_format='tsv', num_epochs=num_epochs, batch_size=batch_size)
+    seq_label_reader.load_data(train_file, file_format='tsv', num_epochs=num_epochs, batch_size=batch_size)

     # step 2: create a backbone of the model to extract text features
     ernie = palm.backbone.ERNIE.from_config(config)

     # step 3: register the backbone in reader
-    ner_reader.register_with(ernie)
+    seq_label_reader.register_with(ernie)

     # step 4: create the task output head
-    ner_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)
+    seq_label_head = palm.head.SequenceLabel(num_classes, input_dim, dropout_prob)

     # step 5-1: create a task trainer
     trainer = palm.Trainer(task_name)
     # step 5-2: build forward graph with backbone and task head
-    loss_var = trainer.build_forward(ernie, ner_head)
+    loss_var = trainer.build_forward(ernie, seq_label_head)

     # step 6-1*: use warmup
-    n_steps = ner_reader.num_examples * num_epochs // batch_size
+    n_steps = seq_label_reader.num_examples * num_epochs // batch_size
     warmup_steps = int(0.1 * n_steps)
     print('total_steps: {}'.format(n_steps))
     print('warmup_steps: {}'.format(warmup_steps))
@@ -62,43 +62,43 @@ if __name__ == '__main__':
     trainer.build_backward(optimizer=adam, weight_decay=weight_decay)

     # step 7: fit prepared reader and data
-    trainer.fit_reader(ner_reader)
+    trainer.fit_reader(seq_label_reader)

-    # step 8-1*: load pretrained parameters
-    trainer.load_pretrain(pre_params)
-    # step 8-2*: set saver to save model
-    save_steps = (n_steps - 20)
-    print('save_steps: {}'.format(save_steps))
-    trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
-    # step 8-3: start training
-    trainer.train(print_steps=train_print_steps)
+    # # step 8-1*: load pretrained parameters
+    # trainer.load_pretrain(pre_params)
+    # # step 8-2*: set saver to save model
+    save_steps = 1951
+    # print('save_steps: {}'.format(save_steps))
+    # trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
+    # # step 8-3: start training
+    # trainer.train(print_steps=train_print_steps)

     # ----------------------- for prediction -----------------------

     # step 1-1: create readers for prediction
     print('prepare to predict...')
-    predict_ner_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, phase='predict')
+    predict_seq_label_reader = palm.reader.SequenceLabelReader(vocab_path, max_seqlen, label_map, phase='predict')
     # step 1-2: load the training data
-    predict_ner_reader.load_data(predict_file, batch_size)
+    predict_seq_label_reader.load_data(predict_file, batch_size)
     # step 2: create a backbone of the model to extract text features
     pred_ernie = palm.backbone.ERNIE.from_config(config, phase='predict')
     # step 3: register the backbone in reader
-    predict_ner_reader.register_with(pred_ernie)
+    predict_seq_label_reader.register_with(pred_ernie)
     # step 4: create the task output head
-    ner_pred_head = palm.head.SequenceLabel(num_classes, input_dim, phase='predict')
+    seq_label_pred_head = palm.head.SequenceLabel(num_classes, input_dim, phase='predict')
     # step 5: build forward graph with backbone and task head
-    trainer.build_predict_forward(pred_ernie, ner_pred_head)
+    trainer.build_predict_forward(pred_ernie, seq_label_pred_head)
     # step 6: load pretrained model
     pred_model_path = './outputs/ckpt.step' + str(save_steps)
     pred_ckpt = trainer.load_ckpt(pred_model_path)
     # step 7: fit prepared reader and data
-    trainer.fit_reader(predict_ner_reader, phase='predict')
+    trainer.fit_reader(predict_seq_label_reader, phase='predict')
     # step 8: predict
     print('predicting..')
paddlepalm/head/mlm.py
@@ -39,7 +39,6 @@ class MaskLM(Head):
     @property
     def inputs_attrs(self):
         reader = {
             "token_ids": [[-1, -1], 'int64'],
             "mask_label": [[-1], 'int64'],
             "mask_pos": [[-1], 'int64'],
             }
@@ -59,21 +58,19 @@ class MaskLM(Head):
     def build(self, inputs, scope_name=""):
         mask_pos = inputs["reader"]["mask_pos"]
+        word_emb = inputs["backbone"]["embedding_table"]
+        enc_out = inputs["backbone"]["encoder_outputs"]
         if self._is_training:
             mask_label = inputs["reader"]["mask_label"]
-            l1 = fluid.layers.shape(inputs["reader"]["token_ids"])[0]
-            # bxs = inputs["reader"]["token_ids"].shape[2].value
-            l2 = fluid.layers.shape(inputs["reader"]["token_ids"][0])[0]
-            bxs = (l1 * l2).astype(np.int64)
-            # max_position = inputs["reader"]["batchsize_x_seqlen"] - 1
+            l1 = enc_out.shape[0]
+            l2 = enc_out.shape[1]
+            bxs = fluid.layers.fill_constant(shape=[1], value=l1 * l2, dtype='int64')
             max_position = bxs - 1
             mask_pos = fluid.layers.elementwise_min(mask_pos, max_position)
             mask_pos.stop_gradient = True

-        word_emb = inputs["backbone"]["embedding_table"]
-        enc_out = inputs["backbone"]["encoder_outputs"]
         emb_size = word_emb.shape[-1]

         _param_initializer = fluid.initializer.TruncatedNormal(
paddlepalm/multihead_trainer.py
@@ -52,11 +52,12 @@ class MultiHeadTrainer(Trainer):
                 'input_varnames': 'self._pred_input_varname_list',
                 'fetch_list': 'self._pred_fetch_name_list'}

-        self._check_save = lambda: False
+        # self._check_save = lambda: False
         for t in self._trainers:
             t._set_multitask()

-    def build_forward(self, backbone, heads):
+    # def build_forward(self, backbone, heads):
+    def build_forward(self):
         """
         Build forward computation graph for training, which usually built from input layer to loss node.
@@ -67,19 +68,12 @@ class MultiHeadTrainer(Trainer):
         Return:
             - loss_var: a Variable object. The computational graph variable(node) of loss.
         """
-        if isinstance(heads, list):
-            head_dict = {k.name: v for k, v in zip(self._trainers, heads)}
-        elif isinstance(heads, dict):
-            head_dict = heads
-        else:
-            raise ValueError()

         num_heads = len(self._trainers)
-        assert len(head_dict) == num_heads
-        for t in self._trainers:
-            assert t.name in head_dict, "expected: {}, exists: {}".format(t.name, head_dict.keys())
+        head_dict = {}
+        backbone = self._trainers[0]._backbone
+        for i in self._trainers:
+            assert i._task_head is not None and i._backbone is not None, "You should build forward for the {} task".format(i._name)
+            assert i._backbone == backbone, "The backbone for each task must be the same"
+            head_dict[i._name] = i._task_head

         train_prog = fluid.Program()
         train_init_prog = fluid.Program()
@@ -88,27 +82,13 @@ class MultiHeadTrainer(Trainer):
         def get_loss(i):
             head = head_dict[self._trainers[i].name]
             # loss_var = self._trainers[i].build_forward(backbone, head, train_prog, train_init_prog)
             loss_var = self._trainers[i].build_forward(backbone, head)
             return loss_var

-        # task_fns = {}
-        # for i in range(num_heads):
-        #     def task_loss():
-        #         task_id = i
-        #         return lambda: get_loss(task_id)
-        #     task_fns[i] = task_loss()
-        # task_fns = {i: lambda: get_loss(i) for i in range(num_heads)}
-        task_fns = {i: lambda i=i: get_loss(i) for i in range(num_heads)}
+        task_fns = {i: lambda i=i: get_loss(i) for i in range(len(self._trainers))}

         with fluid.program_guard(train_prog, train_init_prog):
             task_id_var = fluid.data(name="__task_id", shape=[1], dtype='int64')
             # task_id_var = fluid.layers.fill_constant(shape=[1], dtype='int64', value=1)
             # print(task_id_var.name)
             loss_var = layers.switch_case(
                 branch_index=task_id_var,
@@ -242,7 +222,6 @@ class MultiHeadTrainer(Trainer):
                 task_rt_outputs = {k[len(self._trainers[task_id].name + '.'):]: v for k, v in rt_outputs.items() if k.startswith(self._trainers[task_id].name + '.')}
                 self._trainers[task_id]._task_head.batch_postprocess(task_rt_outputs)

                 if print_steps > 0 and self._cur_train_step % print_steps == 0:
                     loss = rt_outputs[self._trainers[task_id].name + '.loss']
                     loss = np.mean(np.squeeze(loss)).tolist()
@@ -257,7 +236,7 @@ class MultiHeadTrainer(Trainer):
                            loss, print_steps / time_cost))
                     time_begin = time.time()
-                    self._check_save()
+                    # self._check_save()
                 finish = self._check_finish(self._trainers[task_id].name)
                 if finish:
                     break
@@ -287,7 +266,7 @@ class MultiHeadTrainer(Trainer):
         rt_outputs = self._trainers[task_id].train_one_step(batch)

         self._cur_train_step += 1
-        self._check_save()
+        # self._check_save()
         return rt_outputs, task_id

         # if dev_count > 1:
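The `lambda i=i:` default-argument form in `task_fns` (kept in both the old and the new line) is what makes each `switch_case` branch call `get_loss` with its own task index; without it, every closure would see the final value of the loop variable. A standalone illustration of that Python behaviour, independent of Paddle itself:

```python
# Late-binding closures: why `lambda i=i:` matters when building task_fns.
n = 3

late = {i: (lambda: i) for i in range(n)}        # every lambda closes over the same i
bound = {i: (lambda i=i: i) for i in range(n)}   # default argument freezes i per entry

print([late[i]() for i in range(n)])    # [2, 2, 2] -- all branches see the last i
print([bound[i]() for i in range(n)])   # [0, 1, 2] -- each branch keeps its own index
```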
paddlepalm/reader/mlm.py
@@ -34,7 +34,6 @@ class MaskLMReader(Reader):
         for_cn = lang.lower() == 'cn' or lang.lower() == 'chinese'

         self._register.add('token_ids')
         self._register.add('mask_pos')
         if phase == 'train':
             self._register.add('mask_label')
paddlepalm/reader/utils/reader4ernie.py
@@ -99,7 +99,7 @@ class Reader(object):
         if label_map_config:
             with open(label_map_config, encoding='utf8') as f:
-                self.label_map = (f)
+                self.label_map = json.load(f)
         else:
             self.label_map = None
paddlepalm/trainer.py
@@ -54,6 +54,8 @@ class Trainer(object):
         self._train_init = False
         self._predict_init = False
         self._train_init_prog = None
         self._pred_init_prog = None
+        self._check_save = lambda: False
@@ -427,6 +429,7 @@ class Trainer(object):
         self._pred_feed_batch_process_fn = feed_batch_process_fn
         # return distribute_feeder_fn()

     def load_ckpt(self, model_path):
         """
         load training checkpoint for further training or predicting.
@@ -500,7 +503,7 @@ class Trainer(object):
                 convert=convert,
                 main_program=self._train_init_prog)

-    def set_saver(self, save_path, save_steps, save_type='ckpt'):
+    def set_saver(self, save_path, save_steps, save_type='ckpt', is_multi=False):
         """
         create a build-in saver into trainer. A saver will automatically save checkpoint or predict model every `save_steps` training steps.
@@ -511,6 +514,7 @@ class Trainer(object):
         """
         save_type = save_type.split(',')
         if 'predict' in save_type:
             assert self._pred_head is not None, "Predict head not found! You should build_predict_head first if you want to save predict model."
@@ -534,10 +538,19 @@ class Trainer(object):
         def temp_func():
             if (self._save_predict or self._save_ckpt) and self._cur_train_step % save_steps == 0:
                 if self._save_predict:
-                    self._save(save_path, suffix='pred.step' + str(self._cur_train_step))
-                    print('predict model has been saved at ' + os.path.join(save_path, 'pred.step' + str(self._cur_train_step)))
+                    if is_multi:
+                        self._save(save_path, suffix='-pred.step' + str(self._cur_train_step))
+                        print('predict model has been saved at ' + os.path.join(save_path, 'pred.step' + str(self._cur_train_step)))
+                    else:
+                        self._save(save_path, suffix='pred.step' + str(self._cur_train_step))
+                        print('predict model has been saved at ' + os.path.join(save_path, 'pred.step' + str(self._cur_train_step)))
                 if self._save_ckpt:
-                    fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step' + str(self._cur_train_step)), self._train_prog)
-                    print('checkpoint has been saved at ' + os.path.join(save_path, 'ckpt.step' + str(self._cur_train_step)))
+                    if is_multi:
+                        fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step' + str(self._cur_train_step)), self._train_prog)
+                        print('checkpoint has been saved at ' + os.path.join(save_path, 'ckpt.step' + str(self._cur_train_step)))
+                    else:
+                        fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step' + str(self._cur_train_step)), self._train_prog)
+                        print('checkpoint has been saved at ' + os.path.join(save_path, 'ckpt.step' + str(self._cur_train_step)))
                 return True
@@ -600,7 +613,7 @@ class Trainer(object):
                            (self._cur_train_step - 1) % self._steps_pur_epoch + 1, self._steps_pur_epoch, self._cur_train_epoch,
                            loss, print_steps / time_cost))
                     time_begin = time.time()
-                    self._check_save()
+                    # self._check_save()
                 # if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps:
                 #     print(cur_task.name+': train finished!')
                 #     cur_task.save()
@@ -727,6 +740,7 @@ class Trainer(object):
         rt_outputs = {k: v for k, v in zip(self._fetch_names, rt_outputs)}

         self._cur_train_step += 1
+        self._check_save()
         self._cur_train_epoch = (self._cur_train_step - 1) // self._steps_pur_epoch
         return rt_outputs