Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
c04554d9
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c04554d9
编写于
9月 02, 2019
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add shrink
上级
73429ba9
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
26 addition
and
0 deletion
+26
-0
paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h
.../train/custom_trainer/feed/accessor/input_data_accessor.h
+4
-0
paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc
...ain/custom_trainer/feed/accessor/sparse_input_accessor.cc
+7
-0
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
...luid/train/custom_trainer/feed/process/learner_process.cc
+15
-0
未找到文件。
paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h
浏览文件 @
c04554d9
...
...
@@ -26,6 +26,10 @@ public:
virtual
int32_t
create
(
::
paddle
::
framework
::
Scope
*
scope
)
{
return
0
;
}
// 裁剪,用于模型裁剪,base级调用
virtual
int32_t
shrink
()
{
return
0
;
}
// 前向, 一般用于填充输入,在训练网络执行前调用
virtual
int32_t
forward
(
SampleInstance
*
samples
,
size_t
num
,
...
...
paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc
浏览文件 @
c04554d9
...
...
@@ -253,6 +253,13 @@ public:
var_data
[
i
]
+=
pull_raw
[
i
+
2
];
}
}
// 裁剪,用于模型裁剪,base级调用
virtual
int32_t
shrink
()
{
auto
*
ps_client
=
_trainer_context
->
pslib
->
ps_client
();
auto
status
=
ps_client
->
shrink
(
_table_id
);
return
status
.
get
();
}
virtual
void
post_process_input
(
float
*
var_data
,
SparseInputVariable
&
variable
,
SampleInstance
*
samples
,
size_t
num
)
{
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
浏览文件 @
c04554d9
...
...
@@ -169,8 +169,23 @@ int LearnerProcess::run() {
//Step3. Dump Model For Delta&&Checkpoint
{
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveTrainCheckpoint
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
if
(
epoch_accessor
->
is_last_epoch
(
epoch_id
)
&&
environment
->
is_master_node
(
EnvironmentRole
::
WORKER
))
{
paddle
::
platform
::
Timer
timer
;
timer
.
Start
();
VLOG
(
2
)
<<
"Start shrink table"
;
for
(
auto
&
executor
:
_executors
)
{
const
auto
&
table_accessors
=
executor
->
table_accessors
();
for
(
auto
&
itr
:
table_accessors
)
{
CHECK
(
itr
.
second
[
0
]
->
shrink
()
==
0
);
}
}
VLOG
(
2
)
<<
"End shrink table, cost"
<<
timer
.
ElapsedSec
();
}
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
epoch_accessor
->
epoch_done
(
epoch_id
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录