PaddlePaddle / PALM
Commit bf089579
Authored Dec 26, 2019 by xixiaoyao
Parent: af31077d

fix load pretrain
Showing 7 changed files with 135 additions and 6 deletions (+135, -6):
- demo/demo3/run.py (+1, -1)
- paddlepalm/distribute/__init__.py (+9, -0)
- paddlepalm/distribute/reader.py (+109, -0)
- paddlepalm/trainer.py (+14, -4)
- paddlepalm/utils/.saver.py.swp (+0, -0)
- paddlepalm/utils/basic_helper.py (+1, -0)
- paddlepalm/utils/saver.py (+1, -1)
demo/demo3/run.py

```diff
@@ -62,7 +62,7 @@ if __name__ == '__main__':
         use_ema=True, ema_decay=0.999)

     trainer.random_init_params()
-    trainer.load_pretrain('../../pretrain_model/ernie/params')
+    trainer.load_pretrain('pretrain/ernie/params')

     # trainer.train_one_step()
     # trainer.train_one_epoch()
```
paddlepalm/distribute/__init__.py (new file, mode 100644)

```python
from paddle import fluid
import os
import multiprocessing

gpu_dev_count = int(fluid.core.get_cuda_device_count())
cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))

from reader import yield_pieces, data_feeder
```
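The CPU branch honors the `CPU_NUM` environment variable and falls back to the machine's core count when it is unset. A minimal standalone check of that fallback, using a hypothetical value of 4:

```python
import os
import multiprocessing

# When CPU_NUM is set, it takes precedence over the detected core count.
os.environ['CPU_NUM'] = '4'
assert int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) == 4

# When it is unset, multiprocessing.cpu_count() is used instead.
del os.environ['CPU_NUM']
assert int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) == multiprocessing.cpu_count()
```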
paddlepalm/distribute/reader.py (new file, mode 100644)

```python
from . import gpu_dev_count, cpu_dev_count
import Queue
from threading import Thread

dev_count = gpu_dev_count if gpu_dev_count > 0 else cpu_dev_count


def yield_pieces(data, distribute_strategy, batch_size):
    """
    Args:
        distribute_strategy: support s=split, c=copy, u=unstack,
    """
    assert batch_size % dev_count == 0, "batch_size need to be integer times larger than dev_count."
    print('data in yield pieces')
    print(len(data))

    assert type(data) == type(distribute_strategy), [type(data), type(distribute_strategy)]
    assert len(data) == len(distribute_strategy), [len(data), len(distribute_strategy)]

    if isinstance(data, dict):
        keys = list(data.keys())
        data_list = [data[i] for i in keys]
        ds_list = [distribute_strategy[i] for i in keys]
    else:
        assert isinstance(data, list), "the input data must be a list or dict, and contained with multiple tensors."
        data_list = data
        ds_list = distribute_strategy

    stride = batch_size // dev_count
    p = stride
    # while p < len(data_list) + stride:
    while p <= batch_size:
        temp = []
        for d, s in zip(data_list, ds_list):
            s = s.strip().lower()
            if s == 's' or s == 'split':
                if p - stride >= len(d):
                    print('WARNING: no more examples to feed empty devices')
                    temp = []
                    return
                temp.append(d[p - stride:p])
            elif s == 'u' or s == 'unstack':
                assert len(d) <= dev_count, 'Tensor size on dim 0 must be less equal to dev_count when unstack is applied.'
                if p // stride > len(d):
                    print('WARNING: no more examples to feed empty devices')
                    return
                temp.append(d[p // stride - 1])
            elif s == 'c' or s == 'copy':
                temp.append(d)
            else:
                raise NotImplementedError()
        p += stride
        if type(data) == dict:
            yield dict(zip(*[keys, temp]))
        else:
            print('yielded pieces')
            print(len(temp))
            yield temp


def data_feeder(reader, postprocess_fn=None, prefetch_steps=2):
    if postprocess_fn is None:
        def postprocess_fn(batch):
            return batch

    def worker(reader, dev_count, queue):
        dev_batches = []
        for index, data in enumerate(reader()):
            if len(dev_batches) < dev_count:
                dev_batches.append(data)
            if len(dev_batches) == dev_count:
                queue.put((dev_batches, 0))
                dev_batches = []
        # For the prediction of the remained batches, pad more batches to
        # the number of devices and the padded samples would be removed in
        # prediction outputs.
        if len(dev_batches) > 0:
            num_pad = dev_count - len(dev_batches)
            for i in range(len(dev_batches), dev_count):
                dev_batches.append(dev_batches[-1])
            queue.put((dev_batches, num_pad))
        queue.put(None)

    queue = Queue.Queue(dev_count * prefetch_steps)
    p = Thread(target=worker, args=(reader, dev_count, queue))
    p.daemon = True
    p.start()

    while True:
        ret = queue.get()
        queue.task_done()
        if ret is not None:
            batches, num_pad = ret
            batch_buf = []
            flag_buf = []
            for idx, batch in enumerate(batches):
                # flag = num_pad == 0
                flag = idx - len(batches) < -num_pad
                # if num_pad > 0:
                #     num_pad -= 1
                batch = postprocess_fn(batch)
                batch_buf.append(batch)
                flag_buf.append(flag)
            yield batch_buf, flag_buf
        else:
            break
    queue.join()
```
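To make the 's'/'split' strategy in `yield_pieces` concrete, here is a standalone sketch with `dev_count` pinned to 2 purely for illustration (the real module derives it from the visible devices), using made-up batch data:

```python
# Each device receives a contiguous slice of stride = batch_size // dev_count.
dev_count = 2
batch_size = 4
data = {'ids': [1, 2, 3, 4], 'mask': [1, 1, 0, 0]}

stride = batch_size // dev_count
pieces = [{k: v[p - stride:p] for k, v in data.items()}
          for p in range(stride, batch_size + 1, stride)]
print(pieces)
# [{'ids': [1, 2], 'mask': [1, 1]}, {'ids': [3, 4], 'mask': [0, 0]}]
```

The flags yielded by `data_feeder` mark which batches are real and which are duplicated tail-fillers; a small check of the flag expression with hypothetical values:

```python
# With 3 devices and one padded (duplicated) batch, the last flag is False,
# so its outputs can be discarded from prediction results.
batches, num_pad = ['b0', 'b1', 'b1'], 1
flags = [idx - len(batches) < -num_pad for idx in range(len(batches))]
print(flags)  # [True, True, False]
```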
paddlepalm/trainer.py

```diff
@@ -18,7 +18,8 @@ import os
 import json
 from paddle import fluid
 import paddlepalm.utils.basic_helper as helper
-from paddlepalm.utils import reader_helper
+from paddlepalm.utils import reader_helper, saver
+from paddlepalm.distribute import gpu_dev_count
 # from paddlepalm.default_settings import *

 DEBUG = False
@@ -79,7 +80,7 @@ class Trainer(object):
         self._pred_fetch_name_list = []
         self._pred_fetch_var_list = []

-        self._exe = fluid.Executor(fluid.CPUPlace())
+        self._exe = None

         self._save_protocol = {
             'input_names': 'self._pred_input_name_list',
@@ -256,8 +257,17 @@ class Trainer(object):
         return iterator_fn

     def random_init_params(self):
-        helper.build_executor()
+        on_gpu = gpu_dev_count > 0
+        self._exe = helper.build_executor(on_gpu)

+    def load_pretrain(self, model_path):
+        # load pretrain model (or ckpt)
+        assert self._exe is not None, "You need to random_init_params before load pretrain models."
+        saver.init_pretraining_params(
+            self._exe,
+            model_path,
+            main_program=self._train_init_prog)
+
     def _build_head(self, net_inputs, phase, scope=""):
         if phase == 'train':
```
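The new assertion encodes the call-order contract this commit fixes: `random_init_params` now builds `self._exe`, so it must run before `load_pretrain`. A usage sketch following demo/demo3/run.py, with constructor arguments elided:

```python
# Sketch only; Trainer construction arguments are omitted here.
trainer = Trainer(...)                 # no Executor is created at construction anymore
trainer.random_init_params()           # builds self._exe (on GPU when gpu_dev_count > 0)
trainer.load_pretrain('pretrain/ernie/params')  # safe: self._exe exists
```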
paddlepalm/utils/.saver.py.swp (new file, mode 100644)

File added.
paddlepalm/utils/basic_helper.py

```diff
@@ -3,6 +3,7 @@ import os
 import json
 import yaml
 from config_helper import PDConfig
+from paddle import fluid

 def get_basename(f):
     return os.path.splitext(f)[0]
```
paddlepalm/utils/saver.py

```diff
@@ -55,7 +55,7 @@ def init_pretraining_params(exe,
     print("Loading pretraining parameters from {}...".format(pretraining_params_path))

-    with tarfile.open(os.path.join(pretraining_params_path, '__palmmodel__'), 'r:') as f:
+    with tarfile.open(os.path.join(pretraining_params_path, '__palmmodel__'), 'r') as f:
         f.extractall(os.path.join(pretraining_params_path, '.temp'))

     log_path = os.path.join(pretraining_params_path, '__palmmodel__')
```
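The one-character mode change is meaningful for Python's tarfile module: 'r:' opens only an uncompressed archive and raises tarfile.ReadError otherwise, while plain 'r' auto-detects compression. A minimal sketch with a hypothetical local path:

```python
import tarfile

# 'r' transparently handles both plain and compressed tars; 'r:' would fail
# with tarfile.ReadError if __palmmodel__ happened to be compressed.
with tarfile.open('__palmmodel__', 'r') as f:  # hypothetical local path
    f.extractall('.temp')
```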