Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PALM
提交
085a13d2
P
PALM
项目概览
PaddlePaddle
/
PALM
通知
8
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
10
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PALM
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
10
Issue
10
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
085a13d2
编写于
1月 10, 2020
作者:
X
xixiaoyao
浏览文件
操作
浏览文件
下载
差异文件
fix pred
上级
df98c24f
ebe2d6ad
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
41 addition
and
20 deletion
+41
-20
paddlepalm/trainer.py
paddlepalm/trainer.py
+28
-11
paddlepalm/utils/saver.py
paddlepalm/utils/saver.py
+13
-9
未找到文件。
paddlepalm/trainer.py
浏览文件 @
085a13d2
...
@@ -40,6 +40,9 @@ class Trainer(object):
...
@@ -40,6 +40,9 @@ class Trainer(object):
self
.
_task_head
=
task_head
self
.
_task_head
=
task_head
self
.
_pred_head
=
None
self
.
_pred_head
=
None
self
.
_train_init
=
False
self
.
_predict_init
=
False
# if save_predict_model:
# if save_predict_model:
# self._save_predict_model = True
# self._save_predict_model = True
# assert pred_head is not None, "pred_head is required to save predict model."
# assert pred_head is not None, "pred_head is required to save predict model."
...
@@ -220,7 +223,7 @@ class Trainer(object):
...
@@ -220,7 +223,7 @@ class Trainer(object):
for
_id
,
block
in
enumerate
(
self
.
_train_prog
.
blocks
):
for
_id
,
block
in
enumerate
(
self
.
_train_prog
.
blocks
):
for
var
in
block
.
vars
:
for
var
in
block
.
vars
:
print
(
"[debug] : %d, %s"
%
(
_id
,
var
))
print
(
"[debug] : %d, %s"
%
(
_id
,
var
))
self
.
_loss_var
=
loss_var
return
loss_var
return
loss_var
def
build_backward
(
self
,
optimizer
,
weight_decay
=
None
,
use_ema
=
False
,
ema_decay
=
0.9999
):
def
build_backward
(
self
,
optimizer
,
weight_decay
=
None
,
use_ema
=
False
,
ema_decay
=
0.9999
):
...
@@ -296,30 +299,44 @@ class Trainer(object):
...
@@ -296,30 +299,44 @@ class Trainer(object):
distribute_feeder_fn
=
iterator_fn
distribute_feeder_fn
=
iterator_fn
return
distribute_feeder_fn
()
return
distribute_feeder_fn
()
def
random_init_params
(
self
):
def
_init_exe_prog
(
self
,
for_train
=
True
):
assert
self
.
_train_init_prog
is
not
None
,
"train graph not foung! You should build_forward first before you random init parameters."
assert
self
.
_train_init_prog
is
not
None
,
"train graph not foung! You should build_forward first before you random init parameters."
self
.
_distribute_train_prog
=
fluid
.
CompiledProgram
(
self
.
_train_prog
).
with_data_parallel
(
loss_name
=
loss_var
.
name
)
self
.
_train_init
=
True
self
.
_distribute_train_prog
=
fluid
.
CompiledProgram
(
self
.
_train_prog
).
with_data_parallel
(
loss_name
=
self
.
_loss_var
.
name
)
on_gpu
=
gpu_dev_count
>
0
on_gpu
=
gpu_dev_count
>
0
self
.
_exe
=
helper
.
build_executor
(
on_gpu
)
self
.
_exe
=
helper
.
build_executor
(
on_gpu
)
if
not
for_train
:
raise
NotImplementedError
()
def
random_init_params
(
self
):
if
not
self
.
_train_init
:
self
.
_init_exe_prog
()
print
(
'random init params...'
)
print
(
'random init params...'
)
self
.
_exe
.
run
(
self
.
_train_init_prog
)
self
.
_exe
.
run
(
self
.
_train_init_prog
)
def
load_ckpt
(
self
,
model_path
,
phase
=
'train'
):
def
load_ckpt
(
self
,
model_path
,
phase
=
'train'
):
# load pretrain model (or ckpt)
# load pretrain model (or ckpt)
assert
self
.
_exe
is
not
None
,
"You need to random_init_params before load checkpoints."
# assert self._exe is not None, "You need to random_init_params before load checkpoints."
if
phase
==
'train'
and
not
self
.
_train_init
:
self
.
_init_exe_prog
()
if
phase
==
'predict'
and
not
self
.
_predict_init
:
pass
if
phase
==
'train'
:
if
phase
==
'train'
:
assert
self
.
_train_init_prog
is
not
None
,
"train graph not found! You should build_forward first before load checkpoint."
assert
self
.
_train_init_prog
is
not
None
,
"train graph not found! You should build_forward first before load checkpoint."
saver
.
init_pretraining_params
(
saver
.
init_pretraining_params
(
self
.
_exe
,
self
.
_exe
,
model_path
,
model_path
,
main_program
=
self
.
_train_init_prog
)
main_program
=
self
.
_train_init_prog
,
strict
=
True
)
elif
phase
==
'predict'
:
elif
phase
==
'predict'
:
assert
self
.
_pred_init_prog
is
not
None
,
"predict graph not found! You should build_predict_head first before load checkpoint."
assert
self
.
_pred_init_prog
is
not
None
,
"predict graph not found! You should build_predict_head first before load checkpoint."
saver
.
init_pretraining_params
(
saver
.
init_pretraining_params
(
self
.
_exe
,
self
.
_exe
,
model_path
,
model_path
,
main_program
=
self
.
_pred_init_prog
)
main_program
=
self
.
_pred_init_prog
,
strict
=
True
)
else
:
else
:
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -397,6 +414,11 @@ class Trainer(object):
...
@@ -397,6 +414,11 @@ class Trainer(object):
task_rt_outputs
=
{
k
[
len
(
self
.
name
+
'.'
):]:
v
for
k
,
v
in
rt_outputs
.
items
()
if
k
.
startswith
(
self
.
name
+
'.'
)}
task_rt_outputs
=
{
k
[
len
(
self
.
name
+
'.'
):]:
v
for
k
,
v
in
rt_outputs
.
items
()
if
k
.
startswith
(
self
.
name
+
'.'
)}
self
.
_task_head
.
postprocess
(
task_rt_outputs
)
self
.
_task_head
.
postprocess
(
task_rt_outputs
)
# rt_outputs = {k:v for k,v in zip(self._fetch_names, rt_outputs)}
task_rt_outputs
=
{
k
[
len
(
self
.
name
+
'.'
):]:
v
for
k
,
v
in
rt_outputs
.
items
()
if
k
.
startswith
(
self
.
name
+
'.'
)}
self
.
_task_head
.
postprocess
(
task_rt_outputs
)
self
.
_cur_train_step
+=
1
self
.
_cur_train_step
+=
1
self
.
_cur_train_epoch
=
(
self
.
_cur_train_step
-
1
)
//
self
.
_steps_pur_epoch
self
.
_cur_train_epoch
=
(
self
.
_cur_train_step
-
1
)
//
self
.
_steps_pur_epoch
...
@@ -578,11 +600,6 @@ class Trainer(object):
...
@@ -578,11 +600,6 @@ class Trainer(object):
# self._cur_train_step = 1
# self._cur_train_step = 1
# if self._is_target and self._cur_train_step + self._cur_train_epoch * self._steps_pur_epoch >= self._expected_train_steps:
# if self._is_target and self._cur_train_step + self._cur_train_epoch * self._steps_pur_epoch >= self._expected_train_steps:
# self._train_finish = True
# self._train_finish = True
@
property
def
steps_pur_epoch
(
self
):
return
self
.
_steps_pur_epoch
@
steps_pur_epoch
.
setter
@
steps_pur_epoch
.
setter
def
steps_pur_epoch
(
self
,
value
):
def
steps_pur_epoch
(
self
,
value
):
self
.
_steps_pur_epoch
=
value
self
.
_steps_pur_epoch
=
value
...
...
paddlepalm/utils/saver.py
浏览文件 @
085a13d2
...
@@ -46,20 +46,24 @@ def init_checkpoint(exe, init_checkpoint_path, main_program, skip_list = []):
...
@@ -46,20 +46,24 @@ def init_checkpoint(exe, init_checkpoint_path, main_program, skip_list = []):
def
init_pretraining_params
(
exe
,
def
init_pretraining_params
(
exe
,
pretraining_params_path
,
pretraining_params_path
,
convert
,
main_program
):
main_program
):
assert
os
.
path
.
exists
(
pretraining_params_path
assert
os
.
path
.
exists
(
pretraining_params_path
),
"[%s] cann't be found."
%
pretraining_params_path
),
"[%s] cann't be found."
%
pretraining_params_path
if
convert
:
assert
os
.
path
.
exists
(
os
.
path
.
join
(
pretraining_params_path
,
'__palmmodel__'
)),
"__palmmodel__ not found."
assert
os
.
path
.
exists
(
os
.
path
.
join
(
pretraining_params_path
,
'__palmmodel__'
)),
"__palmmodel__ not found."
with
tarfile
.
open
(
os
.
path
.
join
(
pretraining_params_path
,
'__palmmodel__'
),
'r'
)
as
f
:
print
(
"Loading pretraining parameters from {}..."
.
format
(
f
.
extractall
(
os
.
path
.
join
(
pretraining_params_path
,
'.temp'
))
pretraining_params_path
))
log_path
=
os
.
path
.
join
(
pretraining_params_path
,
'__palmmodel__'
)
pretraining_params_path
=
os
.
path
.
join
(
pretraining_params_path
,
'.temp'
)
with
tarfile
.
open
(
os
.
path
.
join
(
pretraining_params_path
,
'__palmmodel__'
),
'r'
)
as
f
:
else
:
f
.
extractall
(
os
.
path
.
join
(
pretraining_params_path
,
'.temp'
))
log_path
=
pretraining_params_path
log_path
=
os
.
path
.
join
(
pretraining_params_path
,
'__palmmodel__'
)
print
(
"Loading pretraining parameters from {}..."
.
format
(
pretraining_params_path
))
pretraining_params_path
=
os
.
path
.
join
(
pretraining_params_path
,
'.temp'
)
def
existed_params
(
var
):
def
existed_params
(
var
):
if
not
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
if
not
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
...
@@ -73,8 +77,8 @@ def init_pretraining_params(exe,
...
@@ -73,8 +77,8 @@ def init_pretraining_params(exe,
pretraining_params_path
,
pretraining_params_path
,
main_program
=
main_program
,
main_program
=
main_program
,
predicate
=
existed_params
)
predicate
=
existed_params
)
if
convert
:
shutil
.
rmtree
(
pretraining_params_path
)
shutil
.
rmtree
(
pretraining_params_path
)
print
(
''
)
print
(
''
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录