Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
PaddleRec
提交
f5af6905
P
PaddleRec
项目概览
BaiXuePrincess
/
PaddleRec
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleRec
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleRec
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f5af6905
编写于
9月 17, 2019
作者:
X
xiexionghang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sparse cache
上级
229964e4
变更
5
显示空白变更内容
内联
并排
Showing 5 changed files with 141 additions and 5 deletions
+141
-5
BCLOUD
BCLOUD
+1
-0
paddle/fluid/train/custom_trainer/feed/main.cc
paddle/fluid/train/custom_trainer/feed/main.cc
+1
-0
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
...luid/train/custom_trainer/feed/process/learner_process.cc
+83
-0
paddle/fluid/train/custom_trainer/feed/process/learner_process.h
...fluid/train/custom_trainer/feed/process/learner_process.h
+5
-1
paddle/fluid/train/custom_trainer/feed/trainer_context.h
paddle/fluid/train/custom_trainer/feed/trainer_context.h
+51
-4
未找到文件。
BCLOUD
浏览文件 @
f5af6905
...
@@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch')
...
@@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch')
CONFIGS('baidu/third-party/python@gcc482output@git_branch')
CONFIGS('baidu/third-party/python@gcc482output@git_branch')
CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag')
CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag')
CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch')
CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch')
CONFIGS('baidu/feed-mlarch/hopscotch-map@stable')
CONFIGS('baidu/paddlepaddle/pslib@stable')
CONFIGS('baidu/paddlepaddle/pslib@stable')
CONFIGS('third-64/gtest@gtest_1-7-0-100_PD_BL')
CONFIGS('third-64/gtest@gtest_1-7-0-100_PD_BL')
HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/')
HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/')
...
...
paddle/fluid/train/custom_trainer/feed/main.cc
浏览文件 @
f5af6905
...
@@ -21,6 +21,7 @@ int main(int argc, char* argv[]) {
...
@@ -21,6 +21,7 @@ int main(int argc, char* argv[]) {
//load trainer config
//load trainer config
auto
trainer_context_ptr
=
std
::
make_shared
<
TrainerContext
>
();
auto
trainer_context_ptr
=
std
::
make_shared
<
TrainerContext
>
();
trainer_context_ptr
->
cache_dict
.
reset
(
new
SignCacheDict
);
trainer_context_ptr
->
trainer_config
=
YAML
::
LoadFile
(
FLAGS_feed_trainer_conf_path
);
trainer_context_ptr
->
trainer_config
=
YAML
::
LoadFile
(
FLAGS_feed_trainer_conf_path
);
//environment
//environment
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
浏览文件 @
f5af6905
...
@@ -16,6 +16,8 @@ namespace feed {
...
@@ -16,6 +16,8 @@ namespace feed {
int
LearnerProcess
::
initialize
(
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
int
LearnerProcess
::
initialize
(
std
::
shared_ptr
<
TrainerContext
>
context_ptr
)
{
int
ret
=
Process
::
initialize
(
context_ptr
);
int
ret
=
Process
::
initialize
(
context_ptr
);
auto
&
config
=
_context_ptr
->
trainer_config
;
auto
&
config
=
_context_ptr
->
trainer_config
;
_is_dump_cache_model
=
config
[
"dump_cache_model"
].
as
<
bool
>
(
false
);
_cache_load_converter
=
config
[
"load_cache_converter"
].
as
<
std
::
string
>
(
""
);
_startup_dump_inference_base
=
config
[
"startup_dump_inference_base"
].
as
<
bool
>
(
false
);
_startup_dump_inference_base
=
config
[
"startup_dump_inference_base"
].
as
<
bool
>
(
false
);
if
(
config
[
"executor"
])
{
if
(
config
[
"executor"
])
{
_executors
.
resize
(
config
[
"executor"
].
size
());
_executors
.
resize
(
config
[
"executor"
].
size
());
...
@@ -27,6 +29,53 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
...
@@ -27,6 +29,53 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
return
0
;
return
0
;
}
}
// 更新各节点存储的CacheModel
int
LearnerProcess
::
update_cache_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
)
{
auto
fs
=
_context_ptr
->
file_system
;
auto
*
ps_client
=
_context_ptr
->
pslib
->
ps_client
();
auto
*
environment
=
_context_ptr
->
environment
.
get
();
auto
*
epoch_accessor
=
_context_ptr
->
epoch_accessor
.
get
();
if
(
!
epoch_accessor
->
need_save_model
(
epoch_id
,
way
))
{
return
0
;
}
auto
*
ps_param
=
_context_ptr
->
pslib
->
get_param
();
if
(
_is_dump_cache_model
&&
way
==
ModelSaveWay
::
ModelSaveInferenceBase
)
{
auto
model_dir
=
epoch_accessor
->
model_save_path
(
epoch_id
,
way
);
auto
&
table_param
=
ps_param
->
server_param
().
downpour_server_param
().
downpour_table_param
();
for
(
auto
&
param
:
table_param
)
{
if
(
param
.
type
()
!=
paddle
::
PS_SPARSE_TABLE
)
{
continue
;
}
auto
cache_model_path
=
fs
->
path_join
(
model_dir
,
string
::
format_string
(
"%03d_cache/"
,
param
.
table_id
()));
if
(
!
fs
->
exists
(
cache_model_path
))
{
continue
;
}
auto
&
cache_dict
=
*
(
_context_ptr
->
cache_dict
.
get
());
cache_dict
.
clear
();
cache_dict
.
reserve
(
_cache_sign_max_num
);
auto
cache_file_list
=
fs
->
list
(
fs
->
path_join
(
cache_model_path
,
"part*"
));
for
(
auto
&
cache_path
:
cache_file_list
)
{
auto
cache_file
=
fs
->
open_read
(
cache_path
,
_cache_load_converter
);
char
*
buffer
=
nullptr
;
size_t
buffer_size
=
0
;
ssize_t
line_len
=
0
;
while
((
line_len
=
getline
(
&
buffer
,
&
buffer_size
,
cache_file
.
get
()))
!=
-
1
)
{
if
(
line_len
<=
1
)
{
continue
;
}
char
*
data_ptr
;
cache_dict
.
append
(
strtoul
(
buffer
,
&
data_ptr
,
10
));
}
if
(
buffer
!=
nullptr
)
{
free
(
buffer
);
}
}
break
;
}
}
return
0
;
}
int
LearnerProcess
::
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
)
{
int
LearnerProcess
::
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
)
{
auto
fs
=
_context_ptr
->
file_system
;
auto
fs
=
_context_ptr
->
file_system
;
auto
*
ps_client
=
_context_ptr
->
pslib
->
ps_client
();
auto
*
ps_client
=
_context_ptr
->
pslib
->
ps_client
();
...
@@ -65,7 +114,38 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is
...
@@ -65,7 +114,38 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is
}
}
timer
.
Pause
();
timer
.
Pause
();
VLOG
(
2
)
<<
"Save Model Cost(s):"
<<
timer
.
ElapsedSec
();
VLOG
(
2
)
<<
"Save Model Cost(s):"
<<
timer
.
ElapsedSec
();
// save cache model, 只有inference需要cache_model
auto
*
ps_param
=
_context_ptr
->
pslib
->
get_param
();
if
(
_is_dump_cache_model
&&
(
way
==
ModelSaveWay
::
ModelSaveInferenceBase
||
way
==
ModelSaveWay
::
ModelSaveInferenceDelta
))
{
auto
&
table_param
=
ps_param
->
server_param
().
downpour_server_param
().
downpour_table_param
();
for
(
auto
&
param
:
table_param
)
{
if
(
param
.
type
()
!=
paddle
::
PS_SPARSE_TABLE
)
{
continue
;
}
double
cache_threshold
=
0.0
;
auto
status
=
ps_client
->
get_cache_threshold
(
param
.
table_id
(),
cache_threshold
);
CHECK
(
status
.
get
()
==
0
)
<<
"CacheThreshold Get failed!"
;
status
=
ps_client
->
cache_shuffle
(
param
.
table_id
(),
model_dir
,
std
::
to_string
((
int
)
way
),
std
::
to_string
(
cache_threshold
));
CHECK
(
status
.
get
()
==
0
)
<<
"Cache Shuffler Failed"
;
status
=
ps_client
->
save_cache
(
param
.
table_id
(),
model_dir
,
std
::
to_string
((
int
)
way
));
auto
feature_size
=
status
.
get
();
CHECK
(
feature_size
>=
0
)
<<
"Cache Save Failed"
;
auto
cache_model_path
=
fs
->
path_join
(
model_dir
,
string
::
format_string
(
"%03d_cache/sparse_cache.meta"
,
param
.
table_id
()));
auto
cache_meta_file
=
fs
->
open_write
(
cache_model_path
,
""
);
auto
meta
=
string
::
format_string
(
"file_prefix:part
\n
part_num:%d
\n
key_num:%d
\n
"
,
param
.
sparse_table_cache_file_num
(),
feature_size
);
CHECK
(
fwrite
(
meta
.
c_str
(),
meta
.
size
(),
1
,
cache_meta_file
.
get
())
==
1
)
<<
"Cache Meta Failed"
;
if
(
feature_size
>
_cache_sign_max_num
)
{
_cache_sign_max_num
=
feature_size
;
}
}
}
_context_ptr
->
epoch_accessor
->
update_model_donefile
(
epoch_id
,
way
);
_context_ptr
->
epoch_accessor
->
update_model_donefile
(
epoch_id
,
way
);
return
all_ret
;
return
all_ret
;
}
}
...
@@ -176,6 +256,9 @@ int LearnerProcess::run() {
...
@@ -176,6 +256,9 @@ int LearnerProcess::run() {
{
{
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
update_cache_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveInferenceBase
);
environment
->
barrier
(
EnvironmentRole
::
WORKER
);
if
(
epoch_accessor
->
is_last_epoch
(
epoch_id
))
{
if
(
epoch_accessor
->
is_last_epoch
(
epoch_id
))
{
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveTrainCheckpointBase
);
wait_save_model
(
epoch_id
,
ModelSaveWay
::
ModelSaveTrainCheckpointBase
);
}
else
{
}
else
{
...
...
paddle/fluid/train/custom_trainer/feed/process/learner_process.h
浏览文件 @
f5af6905
...
@@ -22,9 +22,13 @@ protected:
...
@@ -22,9 +22,13 @@ protected:
virtual
int
load_model
(
uint64_t
epoch_id
);
virtual
int
load_model
(
uint64_t
epoch_id
);
// 同步保存所有模型, is_force_dump:不判断dump条件,强制dump出模型
// 同步保存所有模型, is_force_dump:不判断dump条件,强制dump出模型
virtual
int
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
=
false
);
virtual
int
wait_save_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
,
bool
is_force_dump
=
false
);
virtual
int
update_cache_model
(
uint64_t
epoch_id
,
ModelSaveWay
way
);
private:
private:
bool
_startup_dump_inference_base
;
//启动立即dump base
bool
_is_dump_cache_model
;
// 是否进行cache dump
uint32_t
_cache_sign_max_num
=
0
;
// cache sign最大个数
std
::
string
_cache_load_converter
;
// cache加载的前置转换脚本
bool
_startup_dump_inference_base
;
// 启动立即dump base
std
::
vector
<
std
::
shared_ptr
<
MultiThreadExecutor
>>
_executors
;
std
::
vector
<
std
::
shared_ptr
<
MultiThreadExecutor
>>
_executors
;
};
};
...
...
paddle/fluid/train/custom_trainer/feed/trainer_context.h
浏览文件 @
f5af6905
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <memory>
#include <memory>
#include <vector>
#include <vector>
#include <sstream>
#include <sstream>
#include <tsl/bhopscotch_map.h>
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
...
@@ -35,15 +36,61 @@ enum class TrainerStatus {
...
@@ -35,15 +36,61 @@ enum class TrainerStatus {
Saving
=
1
// 模型存储状态
Saving
=
1
// 模型存储状态
};
};
// Number of float slots stored for every cached sign.
const uint32_t SignCacheMaxValueNum = 13;

// Payload kept for one cached sign.
//   idx          index assigned to the sign (SignCacheDict::append sets it to
//                the sign's position in its sign list)
//   cache_value  fixed-size array of cached float values
struct SignCacheData {
    // Zero-initialise every member so that a freshly constructed record never
    // exposes an indeterminate idx.
    SignCacheData() : idx(0), cache_value() {}

    uint32_t idx;
    float cache_value[SignCacheMaxValueNum];
};
class
SignCacheDict
{
class
SignCacheDict
{
public:
public:
int32_t
sign2index
(
uint64_t
sign
)
{
inline
int32_t
sign2index
(
uint64_t
sign
)
{
auto
itr
=
_sign2data_map
.
find
(
sign
);
if
(
itr
==
_sign2data_map
.
end
())
{
return
-
1
;
return
-
1
;
}
}
return
itr
->
second
.
idx
;
}
uint64_t
index2sign
(
int32_t
index
)
{
inline
uint64_t
index2sign
(
int32_t
index
)
{
if
(
index
>=
_sign_list
.
size
())
{
return
0
;
return
0
;
}
}
return
_sign_list
[
index
];
}
inline
void
reserve
(
uint32_t
size
)
{
_sign_list
.
reserve
(
size
);
_sign2data_map
.
reserve
(
size
);
}
inline
void
clear
()
{
_sign_list
.
clear
();
_sign2data_map
.
clear
();
}
inline
void
append
(
uint64_t
sign
)
{
if
(
_sign2data_map
.
find
(
sign
)
!=
_sign2data_map
.
end
())
{
return
;
}
SignCacheData
data
;
data
.
idx
=
_sign_list
.
size
();
_sign_list
.
push_back
(
sign
);
_sign2data_map
.
emplace
(
sign
,
std
::
move
(
data
));
}
inline
SignCacheData
*
data
(
uint64_t
sign
)
{
tsl
::
bhopscotch_pg_map
<
uint64_t
,
SignCacheData
>::
iterator
itr
=
_sign2data_map
.
find
(
sign
);
if
(
itr
==
_sign2data_map
.
end
())
{
return
nullptr
;
}
return
const_cast
<
SignCacheData
*>
(
&
(
itr
->
second
));
}
private:
std
::
vector
<
uint64_t
>
_sign_list
;
tsl
::
bhopscotch_pg_map
<
uint64_t
,
SignCacheData
>
_sign2data_map
;
};
};
class
TrainerContext
{
class
TrainerContext
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录