Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
9bca1926
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9bca1926
编写于
3月 08, 2019
作者:
H
heqiaozhi
提交者:
dongdaxiang
3月 29, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor & fix bug
上级
2e9a836c
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
116 addition
and
152 deletion
+116
-152
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+14
-17
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+0
-103
paddle/fluid/framework/device_worker.h
paddle/fluid/framework/device_worker.h
+1
-0
paddle/fluid/framework/downpour_worker.cc
paddle/fluid/framework/downpour_worker.cc
+54
-20
paddle/fluid/framework/pull_dense_worker.cc
paddle/fluid/framework/pull_dense_worker.cc
+18
-6
paddle/fluid/framework/trainer_desc.proto
paddle/fluid/framework/trainer_desc.proto
+9
-0
paddle/fluid/pybind/data_set_py.cc
paddle/fluid/pybind/data_set_py.cc
+2
-4
python/paddle/fluid/async_executor.py
python/paddle/fluid/async_executor.py
+3
-2
python/paddle/fluid/trainer_desc.py
python/paddle/fluid/trainer_desc.py
+15
-0
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
9bca1926
...
...
@@ -29,6 +29,7 @@ add_subdirectory(io)
#ddim lib
proto_library
(
framework_proto SRCS framework.proto
)
proto_library
(
data_feed_proto SRCS data_feed.proto
)
proto_library
(
async_executor_proto SRCS data_feed.proto
)
proto_library
(
trainer_desc_proto SRCS trainer_desc.proto
)
cc_library
(
ddim SRCS ddim.cc DEPS eigen3 boost enforce
)
...
...
@@ -174,12 +175,19 @@ endif()
cc_library
(
executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector
)
if
(
WITH_DISTRIBUTE
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method sendrecvop_rpc
${
GLOB_DISTRIBUTE_DEPS
}
graph_to_program_pass variable_helper
${
NGRAPH_EXE_DEPS
}
trainer_library
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog fleet_wrapper
lod_rank_table feed_fetch_method sendrecvop_rpc
${
GLOB_DISTRIBUTE_DEPS
}
graph_to_program_pass variable_helper trainer_library data_feed_proto
${
NGRAPH_EXE_DEPS
}
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
()
cc_library
(
executor SRCS executor.cc multi_trainer.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto data_feed_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass variable_helper
${
NGRAPH_EXE_DEPS
}
timer
)
cc_library
(
executor SRCS executor.cc multi_trainer.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
data_feed.cc device_worker.cc hogwild_worker.cc downpour_worker.cc
pull_dense_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper
${
NGRAPH_EXE_DEPS
}
timer data_feed_proto
)
cc_test
(
test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op
)
endif
()
...
...
@@ -190,8 +198,6 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
graph build_strategy
fast_threaded_ssa_graph_executor variable_helper
)
<<<<<<< HEAD
=======
if
(
WITH_PSLIB
)
cc_library
(
async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
...
...
@@ -201,7 +207,7 @@ if(WITH_PSLIB)
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper pslib_brpc pslib timer
)
variable_helper pslib_brpc pslib timer
fs shell
)
else
()
cc_library
(
async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
...
...
@@ -211,18 +217,9 @@ else()
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper timer
)
variable_helper timer
fs shell
)
endif
(
WITH_PSLIB
)
>>>>>>> 870b88bbd7... add DataSet and InMemoryDataFeed, support load data into memory and shuffle data
cc_library
(
async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
data_set.cc DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table fleet_wrapper lodtensor_printer
feed_fetch_method graph_to_program_pass data_feed_proto
variable_helper timer
)
cc_test
(
data_feed_test SRCS data_feed_test.cc DEPS async_executor
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
...
...
paddle/fluid/framework/data_feed.cc
浏览文件 @
9bca1926
...
...
@@ -220,111 +220,8 @@ void InMemoryDataFeed<T>::LocalShuffle() {
std
::
random_shuffle
(
memory_data_
.
begin
(),
memory_data_
.
end
());
}
// todo global shuffle
/*
template <typename T>
void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
std::random_shuffle(memory_data_.begin(), memory_data_.end());
for (int64_t i = 0; i < memory_data_.size(); ++i) {
// todo get ins id
//std::string ins_id = memory_data_[i].ins_id;
// todo hash
int64_t hash_id = paddle::ps::local_random_engine()();
//int64_t hash_id = hash(ins_id);
int64_t node_id = hash_id % trainer_num_;
std::string str;
SerializeIns(memory_data_[i], str);
auto fleet_ptr = FleetWrapper::GetInstance();
auto ret = fleet_ptr->send_client2client_msg(0, node_id, str);
}
}
*/
template
class
InMemoryDataFeed
<
std
::
vector
<
MultiSlotType
>
>
;
template
<
typename
T
>
InMemoryDataFeed
<
T
>::
InMemoryDataFeed
()
{
cur_channel_
=
0
;
shuffled_ins_
=
nullptr
;
shuffled_ins_out_
=
nullptr
;
}
template
<
typename
T
>
bool
InMemoryDataFeed
<
T
>::
Start
()
{
DataFeed
::
CheckSetFileList
();
if
(
memory_data_
.
size
()
!=
0
)
{
CHECK_EQ
(
cur_channel_
,
0
);
shuffled_ins_
->
Extend
(
std
::
move
(
memory_data_
));
std
::
vector
<
T
>
().
swap
(
memory_data_
);
}
DataFeed
::
finish_start_
=
true
;
return
true
;
}
template
<
typename
T
>
int
InMemoryDataFeed
<
T
>::
Next
()
{
DataFeed
::
CheckStart
();
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
in_channel
=
nullptr
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
out_channel
=
nullptr
;
if
(
cur_channel_
==
0
)
{
in_channel
=
shuffled_ins_
;
out_channel
=
shuffled_ins_out_
;
}
else
{
in_channel
=
shuffled_ins_out_
;
out_channel
=
shuffled_ins_
;
}
CHECK
(
in_channel
!=
nullptr
);
CHECK
(
out_channel
!=
nullptr
);
int
index
=
0
;
T
instance
;
T
ins_vec
;
while
(
index
<
DataFeed
::
default_batch_size_
)
{
if
(
in_channel
->
Size
()
==
0
)
{
break
;
}
in_channel
->
Pop
(
instance
);
AddInstanceToInsVec
(
&
ins_vec
,
instance
,
index
++
);
out_channel
->
Push
(
std
::
move
(
instance
));
}
DataFeed
::
batch_size_
=
index
;
if
(
DataFeed
::
batch_size_
!=
0
)
{
PutToFeedVec
(
ins_vec
);
}
else
{
cur_channel_
=
1
-
cur_channel_
;
}
return
DataFeed
::
batch_size_
;
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
T
ins
;
DeserializeIns
(
ins
,
ins_str
);
shuffled_ins_
->
Push
(
std
::
move
(
ins
));
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
LoadIntoMemory
()
{
std
::
vector
<
T
>
local_vec
;
std
::
string
filename
;
while
(
DataFeed
::
PickOneFile
(
&
filename
))
{
int
err_no
=
0
;
PrivateQueueDataFeed
<
T
>::
fp_
=
fs_open_read
(
filename
,
&
err_no
,
PrivateQueueDataFeed
<
T
>::
pipe_command_
);
__fsetlocking
(
&*
PrivateQueueDataFeed
<
T
>::
fp_
,
FSETLOCKING_BYCALLER
);
T
instance
;
while
(
ParseOneInstanceFromPipe
(
&
instance
))
{
local_vec
.
push_back
(
instance
);
}
memory_data_
.
insert
(
memory_data_
.
end
(),
local_vec
.
begin
(),
local_vec
.
end
());
std
::
vector
<
T
>
().
swap
(
local_vec
);
}
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
LocalShuffle
()
{
std
::
random_shuffle
(
memory_data_
.
begin
(),
memory_data_
.
end
());
}
// todo global shuffle
/*
template <typename T>
...
...
paddle/fluid/framework/device_worker.h
浏览文件 @
9bca1926
...
...
@@ -63,6 +63,7 @@ class PullDenseWorker {
static
std
::
shared_ptr
<
PullDenseWorker
>
s_instance_
;
std
::
shared_ptr
<
paddle
::
framework
::
FleetWrapper
>
fleet_ptr_
;
PullDenseWorkerParameter
param_
;
DownpourWorkerParameter
dwp_param_
;
Scope
*
root_scope_
;
bool
running_
;
...
...
paddle/fluid/framework/downpour_worker.cc
浏览文件 @
9bca1926
...
...
@@ -69,10 +69,16 @@ void DownpourWorker::Initialize(const TrainerDesc& desc) {
}
void
DownpourWorker
::
CollectLabelInfo
(
size_t
table_idx
)
{
auto
table
=
param_
.
sparse_table
(
table_idx
);
uint64_t
table_id
=
static_cast
<
uint64_t
>
(
param_
.
sparse_table
(
table_idx
).
table_id
());
uint64_t
table_id
=
static_cast
<
uint64_t
>
(
param_
.
program_config
(
0
).
pull_sparse_table_id
(
table_idx
));
TableParameter
table
;
for
(
auto
i
:
param_
.
sparse_table
())
{
if
(
i
.
table_id
()
==
table_id
)
{
table
=
i
;
break
;
}
}
auto
&
feature
=
features_
[
table_id
];
auto
&
feature_label
=
feature_labels_
[
table_id
];
feature_label
.
resize
(
feature
.
size
());
...
...
@@ -103,10 +109,17 @@ void DownpourWorker::CollectLabelInfo(size_t table_idx) {
}
void
DownpourWorker
::
FillSparseValue
(
size_t
table_idx
)
{
auto
table
=
param_
.
sparse_table
(
table_idx
);
uint64_t
table_id
=
static_cast
<
uint64_t
>
(
param_
.
program_config
(
0
).
pull_sparse_table_id
(
table_idx
));
TableParameter
table
;
for
(
auto
i
:
param_
.
sparse_table
())
{
if
(
i
.
table_id
()
==
table_id
)
{
table
=
i
;
break
;
}
}
uint64_t
table_id
=
static_cast
<
uint64_t
>
(
param_
.
sparse_table
(
table_idx
).
table_id
());
auto
&
fea_value
=
feature_values_
[
table_id
];
auto
fea_idx
=
0u
;
...
...
@@ -147,11 +160,20 @@ void DownpourWorker::TrainFiles() {
int
cur_batch
;
while
((
cur_batch
=
device_reader_
->
Next
())
>
0
)
{
// pull sparse here
for
(
size_t
i
=
0
;
i
<
param_
.
sparse_table_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
sparse_table
(
i
).
table_id
());
fleet_ptr_
->
PullSparseVarsSync
(
*
thread_scope_
,
tid
,
sparse_key_names_
[
tid
],
&
features_
[
tid
],
&
feature_values_
[
tid
],
param_
.
sparse_table
(
i
).
fea_dim
());
for
(
size_t
i
=
0
;
i
<
param_
.
program_config
(
0
).
pull_sparse_table_id_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
program_config
(
0
).
pull_sparse_table_id
(
i
));
TableParameter
table
;
for
(
auto
i
:
param_
.
sparse_table
())
{
if
(
i
.
table_id
()
==
tid
)
{
table
=
i
;
break
;
}
}
fleet_ptr_
->
PullSparseVarsSync
(
*
thread_scope_
,
tid
,
sparse_key_names_
[
tid
],
&
features_
[
tid
],
&
feature_values_
[
tid
],
table
.
fea_dim
());
CollectLabelInfo
(
i
);
FillSparseValue
(
i
);
}
...
...
@@ -172,17 +194,27 @@ void DownpourWorker::TrainFiles() {
}
// push gradients here
for
(
size_t
i
=
0
;
i
<
param_
.
sparse_table_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
sparse_table
(
i
).
table_id
());
for
(
size_t
i
=
0
;
i
<
param_
.
program_config
(
0
).
push_sparse_table_id_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
program_config
(
0
).
push_sparse_table_id
(
i
));
TableParameter
table
;
for
(
auto
i
:
param_
.
sparse_table
())
{
if
(
i
.
table_id
()
==
tid
)
{
table
=
i
;
break
;
}
}
fleet_ptr_
->
PushSparseVarsWithLabelAsync
(
*
thread_scope_
,
tid
,
features_
[
tid
],
feature_labels_
[
tid
],
sparse_key_names_
[
tid
],
sparse_grad_names_
[
tid
],
param_
.
sparse_table
(
i
).
emb_dim
(),
&
feature_grads_
[
tid
],
&
push_sparse_status_
);
sparse_key_names_
[
tid
],
sparse_grad_names_
[
tid
],
table
.
emb_dim
(),
&
feature_grads_
[
tid
],
&
push_sparse_status_
);
}
for
(
size_t
i
=
0
;
i
<
param_
.
dense_table_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
dense_table
(
i
).
table_id
());
for
(
size_t
i
=
0
;
i
<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
program_config
(
0
).
push_dense_table_id
(
i
));
fleet_ptr_
->
PushDenseVarsAsync
(
*
thread_scope_
,
tid
,
dense_grad_names_
[
tid
],
&
push_sparse_status_
);
}
...
...
@@ -219,8 +251,10 @@ void DownpourWorker::TrainFiles() {
push_sparse_status_
.
resize
(
0
);
}
for
(
size_t
i
=
0
;
i
<
param_
.
dense_table_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
dense_table
(
i
).
table_id
());
for
(
size_t
i
=
0
;
i
<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
program_config
(
0
).
push_dense_table_id
(
i
));
pull_dense_worker_
->
IncreaseThreadVersion
(
thread_id_
,
tid
);
}
...
...
paddle/fluid/framework/pull_dense_worker.cc
浏览文件 @
9bca1926
...
...
@@ -28,16 +28,26 @@ std::map<uint64_t, std::vector<std::string>>
void
PullDenseWorker
::
Initialize
(
const
TrainerDesc
&
param
)
{
running_
=
false
;
param_
=
param
.
pull_dense_param
();
dwp_param_
=
param
.
downpour_param
();
threshold_
=
param_
.
threshold
();
thread_num_
=
param_
.
device_num
();
sleep_time_ms_
=
param_
.
sleep_time_ms
();
for
(
size_t
i
=
0
;
i
<
param_
.
dense_table_size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
dwp_param_
.
program_config
(
0
).
pull_dense_table_id_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
dwp_param_
.
program_config
(
0
).
pull_dense_table_id
(
i
));
TableParameter
table
;
for
(
auto
i
:
param_
.
dense_table
())
{
if
(
i
.
table_id
()
==
tid
)
{
table
=
i
;
break
;
}
}
// setup dense variables for each table
int
var_num
=
param_
.
dense_table
(
i
).
dense_value_name_size
();
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
dense_table
(
i
).
table_id
());
int
var_num
=
table
.
dense_value_name_size
();
dense_value_names_
[
tid
].
resize
(
var_num
);
for
(
int
j
=
0
;
j
<
var_num
;
++
j
)
{
dense_value_names_
[
tid
][
j
]
=
param_
.
dense_table
(
i
)
.
dense_value_name
(
j
);
dense_value_names_
[
tid
][
j
]
=
table
.
dense_value_name
(
j
);
}
// setup training version for each table
training_versions_
[
tid
].
resize
(
thread_num_
,
0
);
...
...
@@ -82,8 +92,10 @@ int PullDenseWorker::Start() {
void
PullDenseWorker
::
Run
()
{
while
(
running_
)
{
pull_dense_status_
.
resize
(
0
);
for
(
size_t
i
=
0
;
i
<
param_
.
dense_table_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
dense_table
(
i
).
table_id
());
for
(
size_t
i
=
0
;
i
<
dwp_param_
.
program_config
(
0
).
pull_dense_table_id_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
dwp_param_
.
program_config
(
0
).
pull_dense_table_id
(
i
));
if
(
CheckUpdateParam
(
tid
))
{
fleet_ptr_
->
PullDenseVarsAsync
(
*
root_scope_
,
tid
,
dense_value_names_
[
tid
],
&
pull_dense_status_
);
...
...
paddle/fluid/framework/trainer_desc.proto
浏览文件 @
9bca1926
...
...
@@ -45,6 +45,15 @@ message DownpourWorkerParameter {
repeated
TableParameter
sparse_table
=
1
;
repeated
TableParameter
dense_table
=
2
;
repeated
string
skip_ops
=
3
;
repeated
ProgramConfig
program_config
=
4
;
}
message
ProgramConfig
{
required
string
program_id
=
1
;
repeated
int32
push_sparse_table_id
=
2
;
repeated
int32
push_dense_table_id
=
3
;
repeated
int32
pull_sparse_table_id
=
4
;
repeated
int32
pull_dense_table_id
=
5
;
}
message
PullDenseWorkerParameter
{
...
...
paddle/fluid/pybind/data_set_py.cc
浏览文件 @
9bca1926
...
...
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3)
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
...
...
@@ -43,7 +41,7 @@ namespace paddle {
namespace
pybind
{
void
BindDataset
(
py
::
module
*
m
)
{
py
::
class_
<
framework
::
Data
S
et
>
(
*
m
,
"Dataset"
)
py
::
class_
<
framework
::
Data
s
et
>
(
*
m
,
"Dataset"
)
.
def
(
py
::
init
([]()
{
return
std
::
unique_ptr
<
framework
::
Dataset
>
(
new
framework
::
Dataset
());
}))
...
...
@@ -53,7 +51,7 @@ void BindDataset(py::module* m) {
.
def
(
"set_data_feed_desc"
,
&
framework
::
Dataset
::
SetDataFeedDesc
)
.
def
(
"load_into_memory"
,
&
framework
::
Dataset
::
LoadIntoMemory
)
.
def
(
"local_shuffle"
,
&
framework
::
Dataset
::
LocalShuffle
)
.
def
(
"global_shuffle"
,
&
framework
::
Dataset
::
G
LobalShuffle
)
.
def
(
"global_shuffle"
,
&
framework
::
Dataset
::
G
lobalShuffle
);
}
}
// end namespace pybind
...
...
python/paddle/fluid/async_executor.py
浏览文件 @
9bca1926
...
...
@@ -118,12 +118,13 @@ class AsyncExecutor(object):
trainer
.
set_thread
(
thread_num
)
trainer
.
set_filelist
(
filelist
)
trainer
.
set_data_feed
(
data_feed
)
if
not
is_local
:
trainer
.
set_program_config
(
self
.
dist_desc
,
str
(
id
(
program
)))
with
open
(
"trainer_desc.proto"
,
"w"
)
as
fout
:
fout
.
write
(
trainer
.
_desc
())
# define a trainer and a device_worker here
self
.
executor
.
run_from_files
(
program_desc
,
trainer
.
_desc
(),
debug
,
str
(
id
(
program_desc
)))
trainer
.
_desc
(),
debug
)
'''
def run(self,
...
...
python/paddle/fluid/trainer_desc.py
浏览文件 @
9bca1926
...
...
@@ -78,3 +78,18 @@ class DistMultiTrainer(TrainerDesc):
worker_builder
=
DeviceWorkerFactory
()
device_worker
=
worker_builder
.
create_device_worker
(
"Downpour"
)
device_worker
.
gen_worker_desc
(
self
.
proto_desc
,
fleet_desc
)
def
set_program_config
(
self
,
fleet_desc
,
program_id
):
for
program_config
in
fleet_desc
.
trainer_param
.
program_config
:
if
program_config
.
program_id
==
program_id
:
pc
=
self
.
proto_desc
.
downpour_param
.
program_config
.
add
()
pc
.
program_id
=
program_config
.
program_id
for
i
in
program_config
.
push_sparse_table_id
:
pc
.
push_sparse_table_id
.
extend
([
i
])
for
i
in
program_config
.
push_dense_table_id
:
pc
.
push_dense_table_id
.
extend
([
i
])
for
i
in
program_config
.
pull_sparse_table_id
:
pc
.
pull_sparse_table_id
.
extend
([
i
])
for
i
in
program_config
.
pull_dense_table_id
:
pc
.
pull_dense_table_id
.
extend
([
i
])
break
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录