Commit ecfc7df9

add dataset factory && fix style

Authored on Mar 13, 2019 by xujiaqi01; committed by dongdaxiang on Mar 29, 2019.
Parent: 328f11b8

Showing 13 changed files with 224 additions and 60 deletions (+224 −60).
Changed files:

paddle/fluid/framework/CMakeLists.txt             +2  -2
paddle/fluid/framework/data_feed.cc               +20 -19
paddle/fluid/framework/data_feed.h                +32 -5
paddle/fluid/framework/data_set.cc                +5  -1
paddle/fluid/framework/dataset_factory.cc         +67 -0
paddle/fluid/framework/dataset_factory.h          +29 -0
paddle/fluid/framework/fleet/fleet_wrapper.cc     +36 -14
paddle/fluid/framework/fleet/fleet_wrapper.h      +6  -4
paddle/fluid/framework/multi_trainer.cc           +1  -0
paddle/fluid/pybind/data_set_py.cc                +13 -11
python/paddle/fluid/dataset.py                    +8  -2
python/paddle/fluid/trainer_desc.py               +2  -2
python/paddle/fluid/trainer_factory.py            +3  -0
paddle/fluid/framework/CMakeLists.txt

@@ -181,7 +181,7 @@ graph_to_program_pass variable_helper trainer_library data_feed_proto ${NGRAPH_E
 set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
 set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 else()
-cc_library(executor SRCS executor.cc multi_trainer.cc
+cc_library(executor SRCS executor.cc multi_trainer.cc dataset_factory.cc
 dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
 data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
 pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry

@@ -202,7 +202,7 @@ cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.
 executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
 trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
 downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
-data_set.cc
+data_set.cc dataset_factory.cc
 DEPS op_registry device_context scope framework_proto
 trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
 feed_fetch_method graph_to_program_pass data_feed_proto
...

paddle/fluid/framework/data_feed.cc
@@ -158,8 +158,6 @@ bool InMemoryDataFeed<T>::Start() {
   DataFeed::CheckSetFileList();
   if (shuffled_ins_->Size() == 0 && shuffled_ins_out_->Size() == 0) {
     FillMemoryDataToChannel();
-    //std::unique_lock<std::mutex> lock(*mutex_for_update_memory_data_);
-    //std::vector<T>().swap(memory_data_);
   }
   DataFeed::finish_start_ = true;
   return true;
@@ -227,13 +225,13 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
 template <typename T>
 void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
   T ins;
-  DeserializeIns(ins, ins_str);
+  DeserializeIns(&ins, ins_str);
   shuffled_ins_->Push(std::move(ins));
 }

 template <typename T>
 void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
-  VLOG(3) << "InMemoryDataFeed<T>::FillMemoryDataToChannel, thread_id=" << thread_id_;
+  VLOG(3) << "FillMemoryDataToChannel, thread_id=" << thread_id_;
   int64_t start = 0;
   int64_t end = 0;
   int64_t size = memory_data_->size();
@@ -252,7 +250,7 @@ void InMemoryDataFeed<T>::FillMemoryDataToChannel() {
 template <typename T>
 void InMemoryDataFeed<T>::FillChannelToMemoryData() {
-  VLOG(3) << "InMemoryDataFeed<T>::FillChannelToMemoryData, thread_id=" << thread_id_;
+  VLOG(3) << "FillChannelToMemoryData, thread_id=" << thread_id_;
   std::vector<T> local_vec;
   std::shared_ptr<paddle::framework::BlockingQueue<T>> channel = nullptr;
   if (cur_channel_ == 0) {
@@ -274,11 +272,12 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() {
 template <typename T>
 void InMemoryDataFeed<T>::LoadIntoMemory() {
-  VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() begin, thread_id=" << thread_id_;
+  VLOG(3) << "LoadIntoMemory() begin, thread_id=" << thread_id_;
   std::vector<T> local_vec;
   std::string filename;
   while (DataFeed::PickOneFile(&filename)) {
-    VLOG(3) << "PickOneFile, filename=" << filename << ", thread_id=" << thread_id_;
+    VLOG(3) << "PickOneFile, filename=" << filename
+            << ", thread_id=" << thread_id_;
     int err_no = 0;
     PrivateQueueDataFeed<T>::fp_ =
         fs_open_read(filename, &err_no, PrivateQueueDataFeed<T>::pipe_command_);
@@ -287,36 +286,38 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
     while (ParseOneInstanceFromPipe(&instance)) {
       local_vec.push_back(instance);
     }
-    VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() read all lines, thread_id=" << thread_id_;
+    VLOG(3) << "LoadIntoMemory() read all lines, file=" << filename
+            << ", thread_id=" << thread_id_;
     {
       std::lock_guard<std::mutex> lock(*mutex_for_update_memory_data_);
-      memory_data_->insert(memory_data_->end(), local_vec.begin(), local_vec.end());
+      memory_data_->insert(memory_data_->end(), local_vec.begin(),
+                           local_vec.end());
     }
     std::vector<T>().swap(local_vec);
   }
-  VLOG(3) << "InMemoryDataFeed<T>::LoadIntoMemory() end, thread_id=" << thread_id_;
+  VLOG(3) << "LoadIntoMemory() end, thread_id=" << thread_id_;
 }

 template <typename T>
 void InMemoryDataFeed<T>::LocalShuffle() {
-  VLOG(3) << "InMemoryDataFeed<T>::LocalShuffle() begin, thread_id=" << thread_id_;
+  VLOG(3) << "LocalShuffle() begin, thread_id=" << thread_id_;
   FillMemoryDataToChannel();
-  VLOG(3) << "InMemoryDataFeed<T>::LocalShuffle() end, thread_id=" << thread_id_;
+  VLOG(3) << "LocalShuffle() end, thread_id=" << thread_id_;
 }

 template <typename T>
 void InMemoryDataFeed<T>::GlobalShuffle() {
+  VLOG(3) << "GlobalShuffle(), thread_id=" << thread_id_;
   auto fleet_ptr = FleetWrapper::GetInstance();
   std::vector<std::string> send_str_vec(trainer_num_);
   for (int64_t i = 0; i < memory_data_->size(); ++i) {
     // todo get ins id
-    //std::string ins_id = memory_data_[i].ins_id;
+    // std::string ins_id = memory_data_[i].ins_id;
     // todo hash
     //int64_t hash_id = paddle::ps::local_random_engine()();
-    int64_t hash_id = 0;
-    int64_t node_id = hash_id % trainer_num_;
+    int64_t random_num = fleet_ptr->local_random_engine()();
+    int64_t node_id = random_num % trainer_num_;
     std::string str;
-    SerializeIns((*memory_data_)[i], str);
+    SerializeIns((*memory_data_)[i], &str);
     send_str_vec[node_id] += str;
     if (i % fleet_send_batch_size_ == 0 && i != 0) {
       for (int j = 0; j < send_str_vec.size(); ++j) {
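
[Editor's note] The rewritten loop above routes each serialized instance to a trainer by drawing from the fleet's random engine and taking the result modulo trainer_num_, batching the bytes per destination. A minimal standalone sketch of that routing pattern follows; Instance, Serialize, and ShardInstances are hypothetical stand-ins, not Paddle APIs.

#include <cstdint>
#include <random>
#include <string>
#include <vector>

// Hypothetical stand-ins for the instance type and its serialized form.
struct Instance { std::string payload; };
std::string Serialize(const Instance& ins) { return ins.payload; }

// Route each instance to a trainer bucket, mirroring GlobalShuffle's
// local_random_engine()() % trainer_num_ routing.
std::vector<std::string> ShardInstances(const std::vector<Instance>& data,
                                        int num_trainers,
                                        std::default_random_engine* engine) {
  std::vector<std::string> send_buffers(num_trainers);
  for (const Instance& ins : data) {
    int64_t random_num = static_cast<int64_t>((*engine)());
    int64_t node_id = random_num % num_trainers;  // destination trainer
    send_buffers[node_id] += Serialize(ins);      // batch bytes per node
  }
  return send_buffers;
}

int main() {
  std::default_random_engine engine(2019);
  std::vector<Instance> data = {{"a"}, {"b"}, {"c"}, {"d"}};
  std::vector<std::string> buckets = ShardInstances(data, 2, &engine);
  return buckets.size() == 2 ? 0 : 1;
}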
@@ -821,12 +822,12 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
 // todo serialize ins in global shuffle
 void MultiSlotInMemoryDataFeed::SerializeIns(
-    const std::vector<MultiSlotType>& ins, std::string& str) {
+    const std::vector<MultiSlotType>& ins, std::string* str) {
   auto fleet_ptr = FleetWrapper::GetInstance();
   fleet_ptr->Serialize(ins, str);
 }

 // todo deserialize ins in global shuffle
-void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>& ins,
+void MultiSlotInMemoryDataFeed::DeserializeIns(std::vector<MultiSlotType>* ins,
                                                const std::string& str) {
   auto fleet_ptr = FleetWrapper::GetInstance();
   fleet_ptr->Deserialize(ins, str);
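
[Editor's note] Throughout this commit, output parameters move from mutable references (std::string& str, T& ins) to pointers (std::string* str, T* ins), in line with the common C++ style rule that outputs be passed by pointer so mutation is visible at the call site. A tiny illustration of the convention, with hypothetical names:

#include <string>

// Output parameter as a pointer: the &buffer at the call site makes the
// mutation explicit, which is the convention this commit adopts for
// SerializeIns/DeserializeIns and FleetWrapper::Serialize/Deserialize.
void Serialize(const std::string& in, std::string* out) { *out = in; }

int main() {
  std::string buffer;
  Serialize("ins", &buffer);  // &buffer signals that buffer is modified
  return buffer == "ins" ? 0 : 1;
}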

paddle/fluid/framework/data_feed.h
@@ -212,13 +212,16 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
   virtual void LoadIntoMemory();
   virtual void LocalShuffle();
   virtual void GlobalShuffle();

 protected:
-  virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, int index) = 0;
+  virtual void AddInstanceToInsVec(T* vec_ins, const T& instance,
+                                   int index) = 0;
   virtual bool ParseOneInstance(T* instance) = 0;
   virtual bool ParseOneInstanceFromPipe(T* instance) = 0;
   virtual void PutToFeedVec(const T& ins_vec) = 0;
-  virtual void SerializeIns(const T& ins, std::string& str) = 0;
-  virtual void DeserializeIns(T& ins, const std::string& str) = 0;
+  virtual void SerializeIns(const T& ins, std::string* str) = 0;
+  virtual void DeserializeIns(T* ins, const std::string& str) = 0;

   int thread_id_;
   int thread_num_;
@@ -284,6 +287,28 @@ class MultiSlotType {
   const std::string& GetType() const { return type_; }
   std::string& MutableType() { return type_; }

+  std::string DebugString() {
+    std::stringstream ss;
+    ss << "type: " << type_ << "\n";
+    ss << "offset:\n";
+    ss << "[";
+    for (const size_t& i : offset_) {
+      ss << offset_[i] << ",";
+    }
+    ss << "]\ndata:\n[";
+    if (type_[0] == 'f') {
+      for (const float& i : float_feasign_) {
+        ss << i << ",";
+      }
+    } else {
+      for (const uint64_t& i : uint64_feasign_) {
+        ss << i << ",";
+      }
+    }
+    ss << "]\n";
+    return ss.str();
+  }

 private:
  void CheckType(const std::string& type) const {
    PADDLE_ENFORCE((type == "uint64") || (type == "float"),
@@ -336,8 +361,10 @@ class MultiSlotInMemoryDataFeed
   virtual bool ParseOneInstance(std::vector<MultiSlotType>* instance);
   virtual bool ParseOneInstanceFromPipe(std::vector<MultiSlotType>* instance);
   virtual void PutToFeedVec(const std::vector<MultiSlotType>& ins_vec);
-  virtual void SerializeIns(const std::vector<MultiSlotType>& ins, std::string& str);
-  virtual void DeserializeIns(std::vector<MultiSlotType>& ins, const std::string& str);
+  virtual void SerializeIns(const std::vector<MultiSlotType>& ins,
+                            std::string* str);
+  virtual void DeserializeIns(std::vector<MultiSlotType>* ins,
+                              const std::string& str);
 };
 }  // namespace framework

paddle/fluid/framework/data_set.cc
@@ -54,7 +54,9 @@ void DatasetImpl<T>::SetThreadNum(int thread_num) {
 }

 template <typename T>
-void DatasetImpl<T>::SetTrainerNum(int trainer_num) { trainer_num_ = trainer_num; }
+void DatasetImpl<T>::SetTrainerNum(int trainer_num) {
+  trainer_num_ = trainer_num;
+}

 template <typename T>
 void DatasetImpl<T>::SetDataFeedDesc(const std::string& data_feed_desc_str) {
@@ -115,10 +117,12 @@ void DatasetImpl<T>::GlobalShuffle() {
   // if it is not InMemory, memory_data_ is empty
   std::random_shuffle(memory_data_.begin(), memory_data_.end());
   auto fleet_ptr = FleetWrapper::GetInstance();
+  VLOG(3) << "registe_client2client_msg_handler";
   fleet_ptr->registe_client2client_msg_handler(
       0, [this](int msg_type, int client_id, const std::string& msg) -> int {
        return this->ReceiveFromClient(msg_type, client_id, msg);
      });
+  VLOG(3) << "start global shuffle threads";
   std::vector<std::thread> global_shuffle_threads;
   for (int i = 0; i < thread_num_; ++i) {
     global_shuffle_threads.push_back(
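
[Editor's note] The handler registered above is a MsgHandlerFunc (the std::function<int32_t(int, int, const std::string&)> typedef added to fleet_wrapper.h below), bound to a lambda that captures this so the callback can reach DatasetImpl state. A minimal self-contained sketch of the same registration pattern, with a hypothetical Registry standing in for the pslib worker:

#include <cstdint>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

using MsgHandlerFunc = std::function<int32_t(int, int, const std::string&)>;

// Hypothetical registry standing in for pslib's worker-side handler table.
class Registry {
 public:
  void Register(int msg_type, MsgHandlerFunc handler) {
    handlers_[msg_type] = std::move(handler);
  }
  int32_t Dispatch(int msg_type, int client_id, const std::string& msg) {
    return handlers_.at(msg_type)(msg_type, client_id, msg);
  }

 private:
  std::unordered_map<int, MsgHandlerFunc> handlers_;
};

class Shuffler {
 public:
  explicit Shuffler(Registry* reg) {
    // Capture `this` so the callback reaches instance state, exactly as
    // DatasetImpl<T>::GlobalShuffle does with ReceiveFromClient.
    reg->Register(0, [this](int msg_type, int client_id,
                            const std::string& msg) -> int {
      return this->Receive(msg_type, client_id, msg);
    });
  }
  int32_t Receive(int msg_type, int client_id, const std::string& msg) {
    std::cout << "msg from client " << client_id << ": " << msg << "\n";
    return 0;
  }
};

int main() {
  Registry reg;
  Shuffler shuffler(&reg);
  return reg.Dispatch(0, /*client_id=*/1, "hello");
}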

paddle/fluid/framework/dataset_factory.cc (new file, mode 100644)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/dataset_factory.h"
#include <memory>
#include <string>
#include <unordered_map>

#include "paddle/fluid/framework/data_set.h"

namespace paddle {
namespace framework {
typedef std::shared_ptr<Dataset> (*CreateDatasetFunction)();
typedef std::unordered_map<std::string, CreateDatasetFunction> datasetMap;
datasetMap g_dataset_map;

#define REGISTER_DATASET_CLASS(dataset_class)                   \
  namespace {                                                   \
  std::shared_ptr<Dataset> Creator_##dataset_class() {          \
    return std::shared_ptr<Dataset>(new dataset_class);         \
  }                                                             \
  class __Registerer_##dataset_class {                          \
   public:                                                      \
    __Registerer_##dataset_class() {                            \
      g_dataset_map[#dataset_class] = &Creator_##dataset_class; \
    }                                                           \
  };                                                            \
  __Registerer_##dataset_class g_registerer_##dataset_class;    \
  }  // namespace

std::string DatasetFactory::DatasetTypeList() {
  std::string dataset_types;
  for (auto iter = g_dataset_map.begin(); iter != g_dataset_map.end();
       ++iter) {
    if (iter != g_dataset_map.begin()) {
      dataset_types += ", ";
    }
    dataset_types += iter->first;
  }
  return dataset_types;
}

std::shared_ptr<Dataset> DatasetFactory::CreateDataset(
    std::string dataset_class) {
  if (g_dataset_map.count(dataset_class) < 1) {
    LOG(WARNING) << "Your Dataset " << dataset_class
                 << "is not supported currently";
    LOG(WARNING) << "Supported Dataset: " << DatasetTypeList();
    exit(-1);
  }
  return g_dataset_map[dataset_class]();
}

REGISTER_DATASET_CLASS(MultiSlotDataset);
}  // namespace framework
}  // namespace paddle
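
[Editor's note] For clarity, REGISTER_DATASET_CLASS(MultiSlotDataset) expands to a creator function plus a static registrar object whose constructor inserts the creator into g_dataset_map during static initialization, before main() runs. A hand-expanded, self-contained sketch (the Dataset/MultiSlotDataset stand-ins below are minimal stubs, not the real classes):

#include <memory>
#include <string>
#include <unordered_map>

// Minimal stand-ins so the expansion below compiles on its own.
class Dataset { public: virtual ~Dataset() = default; };
class MultiSlotDataset : public Dataset {};
typedef std::shared_ptr<Dataset> (*CreateDatasetFunction)();
std::unordered_map<std::string, CreateDatasetFunction> g_dataset_map;

// What REGISTER_DATASET_CLASS(MultiSlotDataset) expands to, by hand:
namespace {
std::shared_ptr<Dataset> Creator_MultiSlotDataset() {
  return std::shared_ptr<Dataset>(new MultiSlotDataset);
}
class __Registerer_MultiSlotDataset {
 public:
  __Registerer_MultiSlotDataset() {
    // Runs during static initialization: the registry is populated as a
    // side effect of merely linking the translation unit.
    g_dataset_map["MultiSlotDataset"] = &Creator_MultiSlotDataset;
  }
};
__Registerer_MultiSlotDataset g_registerer_MultiSlotDataset;
}  // namespace

int main() {
  // Equivalent of DatasetFactory::CreateDataset("MultiSlotDataset"):
  return g_dataset_map.count("MultiSlotDataset") == 1 ? 0 : 1;
}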

paddle/fluid/framework/dataset_factory.h (new file, mode 100644)

/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/data_set.h"

namespace paddle {
namespace framework {
class DatasetFactory {
 public:
  static std::string DatasetTypeList();
  static std::shared_ptr<Dataset> CreateDataset(std::string dataset_class);
};
}  // namespace framework
}  // namespace paddle

paddle/fluid/framework/fleet/fleet_wrapper.cc
@@ -27,6 +27,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/fleet/fleet_wrapper.h"
 #include <utility>
+#include "paddle/fluid/framework/data_feed.h"

 namespace paddle {
@@ -45,7 +46,7 @@ paddle::ps::Archive<AR>& operator << (
   ar << ins.GetOffset();
   ar << ins.GetFloatData();
   ar << ins.GetUint64Data();
-    return ar;
+  return ar;
 }

 template <class AR>
@@ -56,7 +57,7 @@ paddle::ps::Archive<AR>& operator >> (
   ar >> ins.MutableOffset();
   ar >> ins.MutableFloatData();
   ar >> ins.MutableUint64Data();
-    return ar;
+  return ar;
 }
 #endif
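
[Editor's note] Both archive operators return the archive itself (return ar;), which is what permits chained streaming such as ar << a << b << c. A minimal sketch of that chaining idiom with a hypothetical Archive type; the real paddle::ps::Archive lives in pslib and is not reproduced here:

#include <cstdint>
#include <string>

// Hypothetical minimal archive mirroring how an operator<< that returns
// the archive enables chained serialization.
struct Archive {
  std::string buffer;
  Archive& operator<<(uint64_t v) {
    buffer.append(reinterpret_cast<const char*>(&v), sizeof(v));
    return *this;  // returning *this is what allows ar << a << b << c;
  }
};

int main() {
  Archive ar;
  ar << uint64_t{1} << uint64_t{2} << uint64_t{3};
  return ar.buffer.size() == 3 * sizeof(uint64_t) ? 0 : 1;
}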
@@ -291,42 +292,63 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
 #endif
 }

 // todo registe_client2client_msg_handler
-int FleetWrapper::registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler) {
-  return 0;
+int FleetWrapper::registe_client2client_msg_handler(int msg_type,
+                                                    MsgHandlerFunc handler) {
+  pslib_ptr_->_worker_ptr->registe_client2client_msg_handler(msg_type,
+                                                             handler);
+  return 0;
 }

 // todo send_client2client_msg
-int FleetWrapper::send_client2client_msg(int msg_type, int to_client_id, const std::string& msg) {
-  return 0;
+int FleetWrapper::send_client2client_msg(int msg_type, int to_client_id,
+                                         const std::string& msg) {
+  pslib_ptr_->_worker_ptr->send_client2client_msg(msg_type, to_client_id, msg);
+  return 0;
 }

+std::default_random_engine& FleetWrapper::local_random_engine() {
+  struct engine_wrapper_t {
+    std::default_random_engine engine;
+    engine_wrapper_t() {
+      struct timespec tp;
+      clock_gettime(CLOCK_REALTIME, &tp);
+      double cur_time = tp.tv_sec + tp.tv_nsec * 1e-9;
+      static std::atomic<uint64_t> x(0);
+      std::seed_seq sseq = {x++, x++, x++, (uint64_t)(cur_time * 1000)};
+      engine.seed(sseq);
+    }
+  };
+  thread_local engine_wrapper_t r;
+  return r.engine;
+}

 template <typename T>
-void FleetWrapper::Serialize(const T& t, std::string& str) {
+void FleetWrapper::Serialize(const T& t, std::string* str) {
 #ifdef PADDLE_WITH_PSLIB
   paddle::ps::BinaryArchive ar;
   ar << t;
-  str = std::string(ar.buffer(), ar.length());
+  *str = std::string(ar.buffer(), ar.length());
 #else
   VLOG(0) << "FleetWrapper::Serialize do nothing when no pslib";
 #endif
 }

 template <typename T>
-void FleetWrapper::Deserialize(T& t, const std::string& str) {
+void FleetWrapper::Deserialize(T* t, const std::string& str) {
 #ifdef PADDLE_WITH_PSLIB
   paddle::ps::BinaryArchive ar;
   ar.set_read_buffer(const_cast<char*>(str.c_str()), str.length(), nullptr);
-  t = ar.get<T>();
+  *t = ar.get<T>();
 #else
   VLOG(0) << "FleetWrapper::Deserialize do nothing when no pslib";
 #endif
 }

 template void FleetWrapper::Serialize<std::vector<MultiSlotType>>(
-    const std::vector<MultiSlotType>&, std::string&);
+    const std::vector<MultiSlotType>&, std::string*);
 template void FleetWrapper::Deserialize(
-    std::vector<MultiSlotType>&, const std::string&);
+    std::vector<MultiSlotType>*, const std::string&);
 }  // end namespace framework
 }  // end namespace paddle
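
[Editor's note] The new local_random_engine() seeds one thread_local engine per thread from the wall clock plus a global atomic counter, so threads constructed in the same instant still get distinct seeds and later calls pay no re-seeding cost. A self-contained sketch of the same pattern, using <chrono> instead of clock_gettime purely for portability; the seeding scheme is otherwise the one above:

#include <atomic>
#include <chrono>
#include <cstdint>
#include <random>

// One independently seeded engine per thread, mirroring
// FleetWrapper::local_random_engine().
std::default_random_engine& local_random_engine() {
  struct engine_wrapper_t {
    std::default_random_engine engine;
    engine_wrapper_t() {
      static std::atomic<uint64_t> x(0);
      uint64_t now_ms = static_cast<uint64_t>(
          std::chrono::duration_cast<std::chrono::milliseconds>(
              std::chrono::system_clock::now().time_since_epoch())
              .count());
      // The atomic counter guarantees distinct seeds even for threads
      // whose wrappers are constructed within the same millisecond.
      std::seed_seq sseq = {x++, x++, x++, now_ms};
      engine.seed(sseq);
    }
  };
  thread_local engine_wrapper_t r;  // constructed once per thread
  return r.engine;
}

int main() { return static_cast<int>(local_random_engine()() % 2); }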

paddle/fluid/framework/fleet/fleet_wrapper.h
@@ -21,7 +21,7 @@ limitations under the License. */
 #endif
 #include <random>
 #include <atomic>
-#include <time.h>
+#include <ctime>
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/scope.h"
@@ -116,13 +116,15 @@ class FleetWrapper {
   typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
   int registe_client2client_msg_handler(int msg_type, MsgHandlerFunc handler);
-  int send_client2client_msg(int msg_type, int to_client_id, const std::string& msg);
+  int send_client2client_msg(int msg_type, int to_client_id,
+                             const std::string& msg);
+  std::default_random_engine& local_random_engine();
   template <typename T>
-  void Serialize(const T& t, std::string& str);
+  void Serialize(const T& t, std::string* str);
   template <typename T>
-  void Deserialize(T& t, const std::string& str);
+  void Deserialize(T* t, const std::string& str);
   static std::shared_ptr<FleetWrapper> GetInstance() {
     if (NULL == s_instance_) {

paddle/fluid/framework/multi_trainer.cc

@@ -65,6 +65,7 @@ void MultiTrainer::Finalize() {
   for (auto& th : threads_) {
     th.join();
   }
+  // todo dataset->DestroyReaders();
 }
 }  // end namespace framework

paddle/fluid/pybind/data_set_py.cc
@@ -21,7 +21,7 @@ limitations under the License. */
 #endif
 #include <string>
 #include <vector>
+#include <memory>
 #include "google/protobuf/io/zero_copy_stream_impl.h"
 #include "google/protobuf/text_format.h"
 #include "paddle/fluid/framework/async_executor.h"
@@ -33,6 +33,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/variant.h"
 #include "paddle/fluid/pybind/data_set_py.h"
+#include "paddle/fluid/framework/dataset_factory.h"

 namespace py = pybind11;
 namespace pd = paddle::framework;
@@ -41,17 +42,18 @@ namespace paddle {
 namespace pybind {
 void BindDataset(py::module* m) {
-  py::class_<framework::MultiSlotDataset>(*m, "MultiSlotDataset")
-      .def(py::init([]() {
-        return std::unique_ptr<framework::MultiSlotDataset>(
-            new framework::MultiSlotDataset());
-      }))
-      .def("set_filelist", &framework::MultiSlotDataset::SetFileList)
-      .def("set_thread_num", &framework::MultiSlotDataset::SetThreadNum)
-      .def("set_trainer_num", &framework::MultiSlotDataset::SetTrainerNum)
-      .def("set_data_feed_desc", &framework::MultiSlotDataset::SetDataFeedDesc)
-      .def("load_into_memory", &framework::MultiSlotDataset::LoadIntoMemory)
-      .def("local_shuffle", &framework::MultiSlotDataset::LocalShuffle)
-      .def("global_shuffle", &framework::MultiSlotDataset::GlobalShuffle);
+  py::class_<framework::Dataset, std::shared_ptr<framework::Dataset>>(
+      *m, "Dataset")
+      .def(py::init([](const std::string& name = "MultiSlotDataset") {
+        return framework::DatasetFactory::CreateDataset(name);
+      }))
+      .def("set_filelist", &framework::Dataset::SetFileList)
+      .def("set_thread_num", &framework::Dataset::SetThreadNum)
+      .def("set_trainer_num", &framework::Dataset::SetTrainerNum)
+      .def("set_data_feed_desc", &framework::Dataset::SetDataFeedDesc)
+      .def("load_into_memory", &framework::Dataset::LoadIntoMemory)
+      .def("local_shuffle", &framework::Dataset::LocalShuffle)
+      .def("global_shuffle", &framework::Dataset::GlobalShuffle);
 }
 }  // end namespace pybind
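
[Editor's note] The binding switches from the concrete MultiSlotDataset to the abstract framework::Dataset with a std::shared_ptr holder, which is what lets py::init return the factory's shared_ptr directly. A minimal sketch of that holder-plus-factory pattern; Base and Create are hypothetical stand-ins for framework::Dataset and DatasetFactory::CreateDataset:

#include <pybind11/pybind11.h>
#include <memory>
#include <string>

namespace py = pybind11;

// Hypothetical stand-ins for framework::Dataset and the factory.
class Base {
 public:
  virtual ~Base() = default;
  virtual std::string Name() const { return "Base"; }
};
std::shared_ptr<Base> Create(const std::string& name) {
  return std::make_shared<Base>();  // a real factory would dispatch on name
}

PYBIND11_MODULE(example, m) {
  // Declaring std::shared_ptr<Base> as the holder type lets the factory's
  // shared_ptr be returned from the py::init factory lambda unchanged.
  py::class_<Base, std::shared_ptr<Base>>(m, "Dataset")
      .def(py::init([](const std::string& name) { return Create(name); }),
           py::arg("name") = "MultiSlotDataset")
      .def("name", &Base::Name);
}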

python/paddle/fluid/dataset.py
@@ -37,7 +37,7 @@ class DatasetBase(object):
         # to decide whether we need create in memory instance
         self.proto_desc = data_feed_pb2.DataFeedDesc()
         self.proto_desc.pipe_command = "cat"
-        self.dataset = core.MultiSlotDataset()
+        self.dataset = core.Dataset("MultiSlotDataset")
         self.thread_num = 0

     def set_pipe_command(self, pipe_command):
@@ -119,10 +119,16 @@ class InMemoryDataset(DatasetBase):
         from .distributed import ps_instance
         instance = ps_instance.PaddlePSInstance(1, 2)
         self.dataset.set_trainer_num(instance.get_worker_num())
-        self.global_shuffle()
+        self.dataset.global_shuffle()


 class QueueDataset(DatasetBase):
     def __init__(self):
         super(QueueDataset, self).__init__()
         self.proto_desc.name = "MultiSlotDataFeed"

     def local_shuffle(self):
         pass

     def global_shuffle(self):
         pass

python/paddle/fluid/trainer_desc.py
@@ -20,7 +20,7 @@ from google.protobuf import text_format
 __all__ = ['TrainerDesc', 'MultiTrainer', 'DistMultiTrainer']

 # can be initialized from train_desc,
 class TrainerDesc(object):
     def __init__(self):
         '''
@@ -59,7 +59,7 @@ class MultiTrainer(TrainerDesc):
     def gen_trainer_desc(self):
         super(MultiTrainer, self).gen_trainer_desc()
         self.proto_desc.class_name = "MultiTrainer"
-        self.device_worker_.gen_worker_desc(self.proto_desc, fleet_desc_)
+        self.device_worker_.gen_worker_desc(self.proto_desc, self.fleet_desc_)


 class DistMultiTrainer(TrainerDesc):

python/paddle/fluid/trainer_factory.py

@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 from trainer_desc import *
+from device_worker import *

 __all__ = ["TrainerFactory"]