Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
39449ba0
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
39449ba0
编写于
3月 13, 2019
作者:
X
xujiaqi01
提交者:
dongdaxiang
3月 29, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug && add DestroyReaders in trainer
上级
3641a78b
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
31 addition
and
16 deletion
+31
-16
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+3
-3
paddle/fluid/framework/data_set.cc
paddle/fluid/framework/data_set.cc
+2
-2
paddle/fluid/framework/dist_multi_trainer.cc
paddle/fluid/framework/dist_multi_trainer.cc
+2
-0
paddle/fluid/framework/fleet/fleet_wrapper.cc
paddle/fluid/framework/fleet/fleet_wrapper.cc
+15
-5
paddle/fluid/framework/fleet/fleet_wrapper.h
paddle/fluid/framework/fleet/fleet_wrapper.h
+5
-5
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+2
-1
paddle/fluid/framework/trainer.h
paddle/fluid/framework/trainer.h
+2
-0
未找到文件。
paddle/fluid/framework/data_feed.cc
浏览文件 @
39449ba0
...
...
@@ -314,21 +314,21 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
// todo get ins id
// std::string ins_id = memory_data_[i].ins_id;
// todo hash
int64_t
random_num
=
fleet_ptr
->
local_random_e
ngine
()();
int64_t
random_num
=
fleet_ptr
->
LocalRandomE
ngine
()();
int64_t
node_id
=
random_num
%
trainer_num_
;
std
::
string
str
;
SerializeIns
((
*
memory_data_
)[
i
],
&
str
);
send_str_vec
[
node_id
]
+=
str
;
if
(
i
%
fleet_send_batch_size_
==
0
&&
i
!=
0
)
{
for
(
int
j
=
0
;
j
<
send_str_vec
.
size
();
++
j
)
{
fleet_ptr
->
send_client2client_m
sg
(
0
,
j
,
send_str_vec
[
j
]);
fleet_ptr
->
SendClientToClientM
sg
(
0
,
j
,
send_str_vec
[
j
]);
send_str_vec
[
j
]
=
""
;
}
}
}
for
(
int
j
=
0
;
j
<
send_str_vec
.
size
();
++
j
)
{
if
(
send_str_vec
[
j
].
length
()
!=
0
)
{
fleet_ptr
->
send_client2client_m
sg
(
0
,
j
,
send_str_vec
[
j
]);
fleet_ptr
->
SendClientToClientM
sg
(
0
,
j
,
send_str_vec
[
j
]);
}
}
}
...
...
paddle/fluid/framework/data_set.cc
浏览文件 @
39449ba0
...
...
@@ -117,8 +117,8 @@ void DatasetImpl<T>::GlobalShuffle() {
// if it is not InMemory, memory_data_ is empty
std
::
random_shuffle
(
memory_data_
.
begin
(),
memory_data_
.
end
());
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
VLOG
(
3
)
<<
"
registe_client2client_msg_h
andler"
;
fleet_ptr
->
registe_client2client_msg_h
andler
(
0
,
VLOG
(
3
)
<<
"
RegisterClientToClientMsgH
andler"
;
fleet_ptr
->
RegisterClientToClientMsgH
andler
(
0
,
[
this
](
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
->
int
{
return
this
->
ReceiveFromClient
(
msg_type
,
client_id
,
msg
);
});
...
...
paddle/fluid/framework/dist_multi_trainer.cc
浏览文件 @
39449ba0
...
...
@@ -25,6 +25,7 @@ namespace framework {
void
DistMultiTrainer
::
Initialize
(
const
TrainerDesc
&
trainer_desc
,
Dataset
*
dataset
)
{
thread_num_
=
trainer_desc
.
thread_num
();
SetDataset
(
dataset
);
workers_
.
resize
(
thread_num_
);
dataset
->
CreateReaders
();
...
...
@@ -55,6 +56,7 @@ void DistMultiTrainer::Finalize() {
th
.
join
();
}
pull_dense_worker_
->
Stop
();
dataset_ptr_
->
DestroyReaders
();
}
}
// end namespace framework
...
...
paddle/fluid/framework/fleet/fleet_wrapper.cc
浏览文件 @
39449ba0
...
...
@@ -292,21 +292,31 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
}
int
FleetWrapper
::
registe_client2client_msg_h
andler
(
int
FleetWrapper
::
RegisterClientToClientMsgH
andler
(
int
msg_type
,
MsgHandlerFunc
handler
)
{
#ifdef PADDLE_WITH_PSLIB
pslib_ptr_
->
_worker_ptr
->
registe_client2client_msg_handler
(
msg_type
,
handler
);
#else
VLOG
(
0
)
<<
"FleetWrapper::RegisterClientToClientMsgHandler"
<<
" does nothing when no pslib"
;
#endif
return
0
;
}
int
FleetWrapper
::
send_client2client_m
sg
(
int
FleetWrapper
::
SendClientToClientM
sg
(
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
)
{
#ifdef PADDLE_WITH_PSLIB
pslib_ptr_
->
_worker_ptr
->
send_client2client_msg
(
msg_type
,
to_client_id
,
msg
);
#else
VLOG
(
0
)
<<
"FleetWrapper::SendClientToClientMsg"
<<
" does nothing when no pslib"
;
#endif
return
0
;
}
std
::
default_random_engine
&
FleetWrapper
::
local_random_e
ngine
()
{
std
::
default_random_engine
&
FleetWrapper
::
LocalRandomE
ngine
()
{
struct
engine_wrapper_t
{
std
::
default_random_engine
engine
;
engine_wrapper_t
()
{
...
...
@@ -330,7 +340,7 @@ void FleetWrapper::Serialize(const T& t, std::string* str) {
ar
<<
t
;
*
str
=
std
::
string
(
ar
.
buffer
(),
ar
.
length
());
#else
VLOG
(
0
)
<<
"FleetWrapper::Serialize do nothing when no pslib"
;
VLOG
(
0
)
<<
"FleetWrapper::Serialize do
es
nothing when no pslib"
;
#endif
}
...
...
@@ -341,7 +351,7 @@ void FleetWrapper::Deserialize(T* t, const std::string& str) {
ar
.
set_read_buffer
(
const_cast
<
char
*>
(
str
.
c_str
()),
str
.
length
(),
nullptr
);
*
t
=
ar
.
get
<
T
>
();
#else
VLOG
(
0
)
<<
"FleetWrapper::Deserialize do nothing when no pslib"
;
VLOG
(
0
)
<<
"FleetWrapper::Deserialize do
es
nothing when no pslib"
;
#endif
}
...
...
paddle/fluid/framework/fleet/fleet_wrapper.h
浏览文件 @
39449ba0
...
...
@@ -115,11 +115,11 @@ class FleetWrapper {
void
GatherServers
(
const
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
);
typedef
std
::
function
<
int32_t
(
int
,
int
,
const
std
::
string
&
)
>
MsgHandlerFunc
;
int
registe_client2client_msg_h
andler
(
int
msg_type
,
MsgHandlerFunc
handler
);
int
send_client2client_m
sg
(
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
);
std
::
default_random_engine
&
local_random_e
ngine
();
int
RegisterClientToClientMsgH
andler
(
int
msg_type
,
MsgHandlerFunc
handler
);
int
SendClientToClientM
sg
(
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
);
std
::
default_random_engine
&
LocalRandomE
ngine
();
template
<
typename
T
>
void
Serialize
(
const
T
&
t
,
std
::
string
*
str
);
...
...
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
39449ba0
...
...
@@ -24,6 +24,7 @@ namespace framework {
void
MultiTrainer
::
Initialize
(
const
TrainerDesc
&
trainer_desc
,
Dataset
*
dataset
)
{
thread_num_
=
trainer_desc
.
thread_num
();
SetDataset
(
dataset
);
// get filelist from trainer_desc here
workers_
.
resize
(
thread_num_
);
VLOG
(
3
)
<<
"worker thread num: "
<<
thread_num_
;
...
...
@@ -65,7 +66,7 @@ void MultiTrainer::Finalize() {
for
(
auto
&
th
:
threads_
)
{
th
.
join
();
}
// todo dataset
->DestroyReaders();
dataset_ptr_
->
DestroyReaders
();
}
}
// end namespace framework
...
...
paddle/fluid/framework/trainer.h
浏览文件 @
39449ba0
...
...
@@ -41,6 +41,7 @@ class TrainerBase {
// model memory are hosted in root_scope
void
SetScope
(
Scope
*
root_scope
);
void
SetDebug
(
const
bool
debug
)
{
debug_
=
debug
;
}
void
SetDataset
(
Dataset
*
dataset_ptr
)
{
dataset_ptr_
=
dataset_ptr
;
}
virtual
void
Initialize
(
const
TrainerDesc
&
trainer_desc
,
Dataset
*
data_set
)
=
0
;
virtual
void
InitTrainerEnv
(
const
ProgramDesc
&
main_program
,
...
...
@@ -52,6 +53,7 @@ class TrainerBase {
protected:
Scope
*
root_scope_
;
bool
debug_
;
Dataset
*
dataset_ptr_
;
};
// general trainer for async execution
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录