Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6a57e807
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6a57e807
编写于
4月 04, 2019
作者:
X
xjqbest
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
remove trainer_id in datafeed and dataset
test=develop
上级
7a759d76
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
2 addition
and
50 deletion
+2
-50
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+2
-23
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+0
-4
paddle/fluid/framework/data_set.cc
paddle/fluid/framework/data_set.cc
+0
-11
paddle/fluid/framework/data_set.h
paddle/fluid/framework/data_set.h
+0
-7
paddle/fluid/pybind/data_set_py.cc
paddle/fluid/pybind/data_set_py.cc
+0
-2
python/paddle/fluid/dataset.py
python/paddle/fluid/dataset.py
+0
-3
未找到文件。
paddle/fluid/framework/data_feed.cc
浏览文件 @
6a57e807
...
...
@@ -237,11 +237,6 @@ void InMemoryDataFeed<T>::SetThreadNum(int thread_num) {
thread_num_
=
thread_num
;
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
SetTrainerId
(
int
trainer_id
)
{
trainer_id_
=
trainer_id
;
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
SetTrainerNum
(
int
trainer_num
)
{
trainer_num_
=
trainer_num
;
...
...
@@ -372,12 +367,10 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
std
::
vector
<
std
::
vector
<
T
*>>
send_vec
(
trainer_num_
);
std
::
vector
<
int
>
send_index
(
trainer_num_
);
std
::
vector
<
T
>
local_send_vec
;
uint64_t
reserve_len
=
fleet_send_batch_size_
/
trainer_num_
;
for
(
auto
&
vec
:
send_vec
)
{
vec
.
reserve
(
reserve_len
);
}
local_send_vec
.
reserve
(
reserve_len
);
for
(
int
i
=
0
;
i
<
trainer_num_
;
++
i
)
{
send_index
[
i
]
=
i
;
}
...
...
@@ -390,23 +383,12 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
// std::string ins_id = memory_data_[i].ins_id;
int64_t
random_num
=
rand_r
(
&
rand_seed
);
int64_t
node_id
=
random_num
%
trainer_num_
;
if
(
node_id
==
trainer_id_
)
{
local_send_vec
.
push_back
((
*
memory_data_
)[
i
]);
}
else
{
send_vec
[
node_id
].
push_back
(
&
((
*
memory_data_
)[
i
]));
}
send_vec
[
node_id
].
push_back
(
&
((
*
memory_data_
)[
i
]));
if
(
i
%
fleet_send_batch_size_
==
0
&&
i
!=
0
)
{
// shuffle the sequence of sending to avoid network timeout error
std
::
random_shuffle
(
send_index
.
begin
(),
send_index
.
end
());
for
(
int
index
=
0
;
index
<
send_index
.
size
();
++
index
)
{
int
j
=
send_index
[
index
];
if
(
j
==
trainer_id_
)
{
VLOG
(
3
)
<<
"send to local, ins num="
<<
local_send_vec
.
size
()
<<
", node_id="
<<
j
<<
", thread_id="
<<
thread_id_
;
shuffled_ins_
->
Extend
(
std
::
move
(
local_send_vec
));
local_send_vec
.
clear
();
continue
;
}
std
::
string
send_str
;
SerializeIns
(
send_vec
[
j
],
&
send_str
);
VLOG
(
3
)
<<
"send str_length="
<<
send_str
.
length
()
...
...
@@ -423,10 +405,7 @@ void InMemoryDataFeed<T>::GlobalShuffle() {
std
::
random_shuffle
(
send_index
.
begin
(),
send_index
.
end
());
for
(
int
index
=
0
;
index
<
send_index
.
size
();
++
index
)
{
int
j
=
send_index
[
index
];
if
(
j
==
trainer_id_
&&
local_send_vec
.
size
()
!=
0
)
{
shuffled_ins_
->
Extend
(
std
::
move
(
local_send_vec
));
std
::
vector
<
T
>
().
swap
(
local_send_vec
);
}
else
if
(
send_vec
[
j
].
size
()
!=
0
)
{
if
(
send_vec
[
j
].
size
()
!=
0
)
{
std
::
string
send_str
;
SerializeIns
(
send_vec
[
j
],
&
send_str
);
VLOG
(
3
)
<<
"send str_length="
<<
send_str
.
length
()
<<
" to node_id="
<<
j
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
6a57e807
...
...
@@ -91,8 +91,6 @@ class DataFeed {
// This function will do nothing at default
virtual
void
SetThreadId
(
int
thread_id
)
{}
// This function will do nothing at default
virtual
void
SetTrainerId
(
int
trainer_id
)
{}
// This function will do nothing at default
virtual
void
SetThreadNum
(
int
thread_num
)
{}
// This function will do nothing at default
virtual
void
SetTrainerNum
(
int
trainer_num
)
{}
...
...
@@ -215,7 +213,6 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual
void
SetMemoryDataMutex
(
std
::
mutex
*
mutex
);
virtual
void
SetThreadId
(
int
thread_id
);
virtual
void
SetThreadNum
(
int
thread_num
);
virtual
void
SetTrainerId
(
int
trainer_id
);
virtual
void
SetTrainerNum
(
int
trainer_num
);
virtual
void
SetFleetSendBatchSize
(
int64_t
size
);
virtual
void
PutInsToChannel
(
const
std
::
string
&
ins_str
);
...
...
@@ -237,7 +234,6 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
int
thread_id_
;
int
thread_num_
;
int
trainer_id_
;
int
trainer_num_
;
uint32_t
rand_seed
;
std
::
vector
<
T
>*
memory_data_
;
...
...
paddle/fluid/framework/data_set.cc
浏览文件 @
6a57e807
...
...
@@ -52,17 +52,6 @@ void DatasetImpl<T>::SetThreadNum(int thread_num) {
thread_num_
=
thread_num
;
}
// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetTrainerId
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetTrainerId
(
int
trainer_id
)
{
trainer_id_
=
trainer_id
;
for
(
auto
reader
:
readers_
)
{
reader
->
SetTrainerId
(
trainer_id
);
}
}
// if you run distributed, and want to do global shuffle,
// set this before global shuffle.
// be sure you call CreateReaders before SetTrainerNum
...
...
paddle/fluid/framework/data_set.h
浏览文件 @
6a57e807
...
...
@@ -45,8 +45,6 @@ class Dataset {
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
=
0
;
// set readers' num
virtual
void
SetThreadNum
(
int
thread_num
)
=
0
;
// set worker rank
virtual
void
SetTrainerId
(
int
trainer_id
)
=
0
;
// set workers' num
virtual
void
SetTrainerNum
(
int
trainer_num
)
=
0
;
// set fleet send batch size
...
...
@@ -61,8 +59,6 @@ class Dataset {
virtual
const
std
::
vector
<
std
::
string
>&
GetFileList
()
=
0
;
// get thread num
virtual
int
GetThreadNum
()
=
0
;
// get worker rank
virtual
int
GetTrainerId
()
=
0
;
// get worker num
virtual
int
GetTrainerNum
()
=
0
;
// get fleet send batch size
...
...
@@ -105,7 +101,6 @@ class DatasetImpl : public Dataset {
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
);
virtual
void
SetThreadNum
(
int
thread_num
);
virtual
void
SetTrainerId
(
int
trainer_id
);
virtual
void
SetTrainerNum
(
int
trainer_num
);
virtual
void
SetFleetSendBatchSize
(
int64_t
size
);
virtual
void
SetHdfsConfig
(
const
std
::
string
&
fs_name
,
...
...
@@ -114,7 +109,6 @@ class DatasetImpl : public Dataset {
virtual
const
std
::
vector
<
std
::
string
>&
GetFileList
()
{
return
filelist_
;
}
virtual
int
GetThreadNum
()
{
return
thread_num_
;
}
virtual
int
GetTrainerId
()
{
return
trainer_id_
;
}
virtual
int
GetTrainerNum
()
{
return
trainer_num_
;
}
virtual
int64_t
GetFleetSendBatchSize
()
{
return
fleet_send_batch_size_
;
}
virtual
std
::
pair
<
std
::
string
,
std
::
string
>
GetHdfsConfig
()
{
...
...
@@ -142,7 +136,6 @@ class DatasetImpl : public Dataset {
std
::
mutex
mutex_for_update_memory_data_
;
int
thread_num_
;
paddle
::
framework
::
DataFeedDesc
data_feed_desc_
;
int
trainer_id_
;
int
trainer_num_
;
std
::
vector
<
std
::
string
>
filelist_
;
size_t
file_idx_
;
...
...
paddle/fluid/pybind/data_set_py.cc
浏览文件 @
6a57e807
...
...
@@ -49,7 +49,6 @@ void BindDataset(py::module* m) {
}))
.
def
(
"set_filelist"
,
&
framework
::
Dataset
::
SetFileList
)
.
def
(
"set_thread_num"
,
&
framework
::
Dataset
::
SetThreadNum
)
.
def
(
"set_trainer_id"
,
&
framework
::
Dataset
::
SetTrainerId
)
.
def
(
"set_trainer_num"
,
&
framework
::
Dataset
::
SetTrainerNum
)
.
def
(
"set_fleet_send_batch_size"
,
&
framework
::
Dataset
::
SetFleetSendBatchSize
)
...
...
@@ -57,7 +56,6 @@ void BindDataset(py::module* m) {
.
def
(
"set_data_feed_desc"
,
&
framework
::
Dataset
::
SetDataFeedDesc
)
.
def
(
"get_filelist"
,
&
framework
::
Dataset
::
GetFileList
)
.
def
(
"get_thread_num"
,
&
framework
::
Dataset
::
GetThreadNum
)
.
def
(
"get_trainer_id"
,
&
framework
::
Dataset
::
GetTrainerId
)
.
def
(
"get_trainer_num"
,
&
framework
::
Dataset
::
GetTrainerNum
)
.
def
(
"get_fleet_send_batch_size"
,
&
framework
::
Dataset
::
GetFleetSendBatchSize
)
...
...
python/paddle/fluid/dataset.py
浏览文件 @
6a57e807
...
...
@@ -240,15 +240,12 @@ class InMemoryDataset(DatasetBase):
Args:
fleet: fleet singleton. Default None.
"""
trainer_id
=
0
trainer_num
=
1
fleet_send_batch_size
=
80000
if
fleet
is
not
None
:
fleet
.
fleet_instance
.
role_maker_
.
_barrier_worker
()
trainer_id
=
fleet
.
worker_index
()
trainer_num
=
fleet
.
worker_num
()
self
.
dataset
.
register_client2client_msg_handler
()
self
.
dataset
.
set_trainer_id
(
trainer_id
)
self
.
dataset
.
set_trainer_num
(
trainer_num
)
self
.
dataset
.
set_fleet_send_batch_size
(
fleet_send_batch_size
)
if
fleet
is
not
None
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录