Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleRec
提交
f5af6905
P
PaddleRec
项目概览
PaddlePaddle
/
PaddleRec
通知
68
Star
12
Fork
5
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
27
列表
看板
标记
里程碑
合并请求
10
Wiki
1
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
27
Issue
27
列表
看板
标记
里程碑
合并请求
10
合并请求
10
Pages
分析
分析
仓库分析
DevOps
Wiki
1
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f5af6905
编写于
9月 17, 2019
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sparse cache
上级
229964e4
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
141 addition
and
5 deletion
+141
-5
BCLOUD
BCLOUD
+1
-0
paddle/fluid/train/custom_trainer/feed/main.cc
paddle/fluid/train/custom_trainer/feed/main.cc
+1
-0
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
...luid/train/custom_trainer/feed/process/learner_process.cc
+83
-0
paddle/fluid/train/custom_trainer/feed/process/learner_process.h
...fluid/train/custom_trainer/feed/process/learner_process.h
+5
-1
paddle/fluid/train/custom_trainer/feed/trainer_context.h
paddle/fluid/train/custom_trainer/feed/trainer_context.h
+51
-4
未找到文件。
BCLOUD
浏览文件 @
f5af6905
...
@@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch')
...
@@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch')
CONFIGS('baidu/third-party/python@gcc482output@git_branch')
CONFIGS('baidu/third-party/python@gcc482output@git_branch')
CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag')
CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag')
CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch')
CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch')
CONFIGS('baidu/feed-mlarch/hopscotch-map@stable')
CONFIGS('baidu/paddlepaddle/pslib@stable')
CONFIGS('baidu/paddlepaddle/pslib@stable')
CONFIGS('third-64/gtest@gtest_1-7-0-100_PD_BL')
CONFIGS('third-64/gtest@gtest_1-7-0-100_PD_BL')
HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/')
HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/')
...
...
paddle/fluid/train/custom_trainer/feed/main.cc
浏览文件 @
f5af6905
...
@@ -21,6 +21,7 @@ int main(int argc, char* argv[]) {
...
@@ -21,6 +21,7 @@ int main(int argc, char* argv[]) {
//load trainer config
//load trainer config
auto
trainer_context_ptr
=
std
::
make_shared
<
TrainerContext
>
();
auto
trainer_context_ptr
=
std
::
make_shared
<
TrainerContext
>
();
trainer_context_ptr
->
cache_dict
.
reset
(
new
SignCacheDict
);
trainer_context_ptr
->
trainer_config
=
YAML
::
LoadFile
(
FLAGS_feed_trainer_conf_path
);
trainer_context_ptr
->
trainer_config
=
YAML
::
LoadFile
(
FLAGS_feed_trainer_conf_path
);
//environment
//environment
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
浏览文件 @
f5af6905
...
@@ -16,6 +16,8 @@ namespace feed {
...
@@ -16,6 +16,8 @@ namespace feed {
int
LearnerProcess
::
initialize
(
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
int
LearnerProcess
::
initialize
(
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
int
ret
=
Process
::
initialize
(
context_ptr
);
int
ret
=
Process
::
initialize
(
context_ptr
);
auto
&
config
=
_context_ptr
->
trainer_config
;
auto
&
config
=
_context_ptr
->
trainer_config
;
_is_dump_cache_model
=
config
[
"dump_cache_model"
].
as
<
bool
>
(
false
);
_cache_load_converter
=
config
[
"load_cache_converter"
].
as
<
std
::
string
>
(
""
);
_startup_dump_inference_base
=
config
[
"startup_dump_inference_base"
].
as
<
bool
>
(
false
);
_startup_dump_inference_base
=
config
[
"startup_dump_inference_base"
].
as
<
bool
>
(
false
);
if
(
config
[
"executor"
])
{
if
(
config
[
"executor"
])
{
_executors
.
resize
(
config
[
"executor"
].
size
());
_executors
.
resize
(
config
[
"executor"
].
size
());
...
@@ -27,6 +29,53 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
...
@@ -27,6 +29,53 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
return
0
;
return
0
;
}
}
// 更新各节点存储的CacheModel
int
LearnerProcess
::
update_cache_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
)
{
auto
fs
=
_context_ptr
->
file_system
;
auto
*
ps_client
=
_context_ptr
->
pslib
->
ps_client
();
auto
*
environment
=
_context_ptr
->
environment
.
get
();
auto
*
epoch_accessor
=
_context_ptr
->
epoch_accessor
.
get
();
if
(
!
epoch_accessor
->
need_save_model
(
epoch_id
,
way
))
{
return
0
;
}
auto
*
ps_param
=
_context_ptr
->
pslib
->
get_param
();
if
(
_is_dump_cache_model
&&
way
==
ModelSaveWay
::
ModelSaveInferenceBase
)
{
auto
model_dir
=
epoch_accessor
->
model_save_path
(
epoch_id
,
way
);
auto
&
table_param
=
ps_param
->
server_param
().
downpour_server_param
().
downpour_table_param
();
for
(
auto
&
param
:
table_param
)
{
if
(
param
.
type
()
!=
paddle
::
PS_SPARSE_TABLE
)
{
continue
;
}
auto
cache_model_path
=
fs
->
path_join
(
model_dir
,
string
::
format_string
(
"%03d_cache/"
,
param
.
table_id
()));
if
(
!
fs
->
exists
(
cache_model_path
))
{
continue
;
}
auto
&
cache_dict
=
*
(
_context_ptr
->
cache_dict
.
get
());
cache_dict
.
clear
();
cache_dict
.
reserve
(
_cache_sign_max_num
);
auto
cache_file_list
=
fs
->
list
(
fs
->
path_join
(
cache_model_path
,
"part*"
));
for
(
auto
&
cache_path
:
cache_file_list
)
{
auto
cache_file
=
fs
->
open_read
(
cache_path
,
_cache_load_converter
);
char
*
buffer
=
nullptr
;
size_t
buffer_size
=
0
;
ssize_t
line_len
=
0
;
while
((
line_len
=
getline
(
&
buffer
,
&
buffer_size
,
cache_file
.
get
()))
!=
-
1
)
{
if
(
line_len
<=
1
)
{
continue
;
}
char
*
data_ptr
;
cache_dict
.
append
(
strtoul
(
buffer
,
&
data_ptr
,
10
));
}
if
(
buffer
!=
nullptr
)
{
free
(
buffer
);
}
}
break
;
}
}
return
0
;
}
int
LearnerProcess
::
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
)
{
int
LearnerProcess
::
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
)
{
auto
fs
=
_context_ptr
->
file_system
;
auto
fs
=
_context_ptr
->
file_system
;
auto
*
ps_client
=
_context_ptr
->
pslib
->
ps_client
();
auto
*
ps_client
=
_context_ptr
->
pslib
->
ps_client
();
...
@@ -65,7 +114,38 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is
...
@@ -65,7 +114,38 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is
}
}
timer
.
Pause
();
timer
.
Pause
();
VLOG
(
2
)
<<
"Save Model Cost(s):"
<<
timer
.
ElapsedSec
();
VLOG
(
2
)
<<
"Save Model Cost(s):"
<<
timer
.
ElapsedSec
();
// save cache model, 只有inference需要cache_model
auto
*
ps_param
=
_context_ptr
->
pslib
->
get_param
();
if
(
_is_dump_cache_model
&&
(
way
==
ModelSaveWay
::
ModelSaveInferenceBase
||
way
==
ModelSaveWay
::
ModelSaveInferenceDelta
))
{
auto
&
table_param
=
ps_param
->
server_param
().
downpour_server_param
().
downpour_table_param
();
for
(
auto
&
param
:
table_param
)
{
if
(
param
.
type
()
!=
paddle
::
PS_SPARSE_TABLE
)
{
continue
;
}
double
cache_threshold
=
0.0
;
auto
status
=
ps_client
->
get_cache_threshold
(
param
.
table_id
(),
cache_threshold
);
CHECK
(
status
.
get
()
==
0
)
<<
"CacheThreshold Get failed!"
;
status
=
ps_client
->
cache_shuffle
(
param
.
table_id
(),
model_dir
,
std
::
to_string
((
int
)
way
),
std
::
to_string
(
cache_threshold
));
CHECK
(
status
.
get
()
==
0
)
<<
"Cache Shuffler Failed"
;
status
=
ps_client
->
save_cache
(
param
.
table_id
(),
model_dir
,
std
::
to_string
((
int
)
way
));
auto
feature_size
=
status
.
get
();
CHECK
(
feature_size
>=
0
)
<<
"Cache Save Failed"
;
auto
cache_model_path
=
fs
->
path_join
(
model_dir
,
string
::
format_string
(
"%03d_cache/sparse_cache.meta"
,
param
.
table_id
()));
auto
cache_meta_file
=
fs
->
open_write
(
cache_model_path
,
""
);
auto
meta
=
string
::
format_string
(
"file_prefix:part
\n
part_num:%d
\n
key_num:%d
\n
"
,
param
.
sparse_table_cache_file_num
(),
feature_size
);
CHECK
(
fwrite
(
meta
.
c_str
(),
meta
.
size
(),
1
,
cache_meta_file
.
get
())
==
1
)
<<
"Cache Meta Failed"
;
if
(
feature_size
>
_cache_sign_max_num
)
{
_cache_sign_max_num
=
feature_size
;
}
}
}
_context_ptr
->
epoch_accessor
->
update_model_donefile
(
epoch_id
,
way
);
_context_ptr
->
epoch_accessor
->
update_model_donefile
(
epoch_id
,
way
);
return
all_ret
;
return
all_ret
;
}
}
...
@@ -176,6 +256,9 @@ int LearnerProcess::run() {
...
@@ -176,6 +256,9 @@ int LearnerProcess::run() {
{
{
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
update_cache_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
if
(
epoch_accessor
->
is_last_epoch
(
epoch_id
))
{
if
(
epoch_accessor
->
is_last_epoch
(
epoch_id
))
{
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveTrainCheckpointBase
);
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveTrainCheckpointBase
);
}
else
{
}
else
{
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.h
浏览文件 @
f5af6905
...
@@ -22,9 +22,13 @@ protected:
...
@@ -22,9 +22,13 @@ protected:
virtual
int
load_model
(
uint64_t
epoch_id
);
virtual
int
load_model
(
uint64_t
epoch_id
);
// 同步保存所有模型, is_force_dump:不判断dump条件,强制dump出模型
// 同步保存所有模型, is_force_dump:不判断dump条件,强制dump出模型
virtual
int
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
=
false
);
virtual
int
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
=
false
);
virtual
int
update_cache_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
);
private:
private:
bool
_startup_dump_inference_base
;
//启动立即dump base
bool
_is_dump_cache_model
;
// 是否进行cache dump
uint32_t
_cache_sign_max_num
=
0
;
// cache sign最大个数
std
::
string
_cache_load_converter
;
// cache加载的前置转换脚本
bool
_startup_dump_inference_base
;
// 启动立即dump base
std
::
vector
<
std
::
shared_ptr
<
MultiThreadExecutor
>>
_executors
;
std
::
vector
<
std
::
shared_ptr
<
MultiThreadExecutor
>>
_executors
;
};
};
...
...
paddle/fluid/train/custom_trainer/feed/trainer_context.h
浏览文件 @
f5af6905
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include <sstream>
#include <sstream>
#include <tsl/bhopscotch_map.h>
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
...
@@ -35,15 +36,61 @@ enum class TrainerStatus {
...
@@ -35,15 +36,61 @@ enum class TrainerStatus {
Saving
=
1
// 模型存储状态
Saving
=
1
// 模型存储状态
};
};
const
uint32_t
SignCacheMaxValueNum
=
13
;
struct
SignCacheData
{
SignCacheData
()
{
memset
(
cache_value
,
0
,
sizeof
(
float
)
*
SignCacheMaxValueNum
);
}
uint32_t
idx
;
float
cache_value
[
SignCacheMaxValueNum
];
};
class
SignCacheDict
{
class
SignCacheDict
{
public:
public:
int32_t
sign2index
(
uint64_t
sign
)
{
inline
int32_t
sign2index
(
uint64_t
sign
)
{
return
-
1
;
auto
itr
=
_sign2data_map
.
find
(
sign
);
if
(
itr
==
_sign2data_map
.
end
())
{
return
-
1
;
}
return
itr
->
second
.
idx
;
}
inline
uint64_t
index2sign
(
int32_t
index
)
{
if
(
index
>=
_sign_list
.
size
())
{
return
0
;
}
return
_sign_list
[
index
];
}
inline
void
reserve
(
uint32_t
size
)
{
_sign_list
.
reserve
(
size
);
_sign2data_map
.
reserve
(
size
);
}
inline
void
clear
()
{
_sign_list
.
clear
();
_sign2data_map
.
clear
();
}
inline
void
append
(
uint64_t
sign
)
{
if
(
_sign2data_map
.
find
(
sign
)
!=
_sign2data_map
.
end
())
{
return
;
}
SignCacheData
data
;
data
.
idx
=
_sign_list
.
size
();
_sign_list
.
push_back
(
sign
);
_sign2data_map
.
emplace
(
sign
,
std
::
move
(
data
));
}
}
uint64_t
index2sign
(
int32_t
index
)
{
inline
SignCacheData
*
data
(
uint64_t
sign
)
{
return
0
;
tsl
::
bhopscotch_pg_map
<
uint64_t
,
SignCacheData
>::
iterator
itr
=
_sign2data_map
.
find
(
sign
);
if
(
itr
==
_sign2data_map
.
end
())
{
return
nullptr
;
}
return
const_cast
<
SignCacheData
*>
(
&
(
itr
->
second
));
}
}
private:
std
::
vector
<
uint64_t
>
_sign_list
;
tsl
::
bhopscotch_pg_map
<
uint64_t
,
SignCacheData
>
_sign2data_map
;
};
};
class
TrainerContext
{
class
TrainerContext
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录