Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
c1650120
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c1650120
编写于
2月 02, 2019
作者:
D
dongdaxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine device_worker and trainer code
test=develop
上级
8a335b50
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
318 addition
and
132 deletion
+318
-132
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+2
-2
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+3
-1
paddle/fluid/framework/async_executor.h
paddle/fluid/framework/async_executor.h
+0
-8
paddle/fluid/framework/device_worker.h
paddle/fluid/framework/device_worker.h
+4
-6
paddle/fluid/framework/device_worker_factory.cc
paddle/fluid/framework/device_worker_factory.cc
+0
-21
paddle/fluid/framework/device_worker_factory.h
paddle/fluid/framework/device_worker_factory.h
+50
-0
paddle/fluid/framework/dist_multi_trainer.cc
paddle/fluid/framework/dist_multi_trainer.cc
+6
-1
paddle/fluid/framework/downpour_worker.cc
paddle/fluid/framework/downpour_worker.cc
+29
-13
paddle/fluid/framework/fleet/CMakeLists.txt
paddle/fluid/framework/fleet/CMakeLists.txt
+1
-1
paddle/fluid/framework/fleet/fleet_wrapper.cc
paddle/fluid/framework/fleet/fleet_wrapper.cc
+42
-5
paddle/fluid/framework/fleet/fleet_wrapper.h
paddle/fluid/framework/fleet/fleet_wrapper.h
+4
-2
paddle/fluid/framework/hogwild_worker.cc
paddle/fluid/framework/hogwild_worker.cc
+8
-6
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+2
-0
paddle/fluid/framework/pull_dense_worker.cc
paddle/fluid/framework/pull_dense_worker.cc
+11
-0
paddle/fluid/framework/trainer.cc
paddle/fluid/framework/trainer.cc
+0
-2
paddle/fluid/framework/trainer.h
paddle/fluid/framework/trainer.h
+1
-1
paddle/fluid/framework/trainer_desc.proto
paddle/fluid/framework/trainer_desc.proto
+0
-1
paddle/fluid/framework/trainer_factory.cc
paddle/fluid/framework/trainer_factory.cc
+0
-19
paddle/fluid/framework/trainer_factory.h
paddle/fluid/framework/trainer_factory.h
+47
-0
python/paddle/fluid/async_executor.py
python/paddle/fluid/async_executor.py
+12
-8
python/paddle/fluid/device_worker.py
python/paddle/fluid/device_worker.py
+75
-0
python/paddle/fluid/trainer_desc.py
python/paddle/fluid/trainer_desc.py
+21
-35
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
c1650120
...
@@ -203,7 +203,7 @@ if(WITH_PSLIB)
...
@@ -203,7 +203,7 @@ if(WITH_PSLIB)
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
DEPS op_registry device_context scope framework_proto
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table
trainer_desc_proto glog lod_rank_table
fleet_wrapper
feed_fetch_method graph_to_program_pass async_executor_proto
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper pslib_brpc pslib timer
)
variable_helper pslib_brpc pslib timer
)
else
()
else
()
...
@@ -212,7 +212,7 @@ else()
...
@@ -212,7 +212,7 @@ else()
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
DEPS op_registry device_context scope framework_proto
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table
trainer_desc_proto glog lod_rank_table
fleet_wrapper
feed_fetch_method graph_to_program_pass async_executor_proto
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper timer
)
variable_helper timer
)
endif
(
WITH_PSLIB
)
endif
(
WITH_PSLIB
)
...
...
paddle/fluid/framework/async_executor.cc
浏览文件 @
c1650120
...
@@ -26,7 +26,9 @@ limitations under the License. */
...
@@ -26,7 +26,9 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/pybind/pybind.h"
...
@@ -161,7 +163,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
...
@@ -161,7 +163,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
trainer
=
TrainerFactory
::
CreateTrainer
(
trainer_desc
.
class_name
());
trainer
=
TrainerFactory
::
CreateTrainer
(
trainer_desc
.
class_name
());
// initialize trainer
// initialize trainer
trainer
->
Initialize
(
trainer_desc
);
trainer
->
Initialize
(
trainer_desc
);
// trainer->SetRoo
tScope(root_scope_);
trainer
->
Se
tScope
(
root_scope_
);
trainer
->
SetDebug
(
debug
);
trainer
->
SetDebug
(
debug
);
// prepare training environment and helper environment
// prepare training environment and helper environment
trainer
->
InitTrainerEnv
(
main_program
,
place_
);
trainer
->
InitTrainerEnv
(
main_program
,
place_
);
...
...
paddle/fluid/framework/async_executor.h
浏览文件 @
c1650120
...
@@ -75,14 +75,6 @@ class AsyncExecutor {
...
@@ -75,14 +75,6 @@ class AsyncExecutor {
void
InitModel
();
void
InitModel
();
void
SaveModel
(
const
std
::
string
&
path
);
void
SaveModel
(
const
std
::
string
&
path
);
private:
void
CreateThreads
(
ExecutorThreadWorker
*
worker
,
const
ProgramDesc
&
main_program
,
const
std
::
shared_ptr
<
DataFeed
>&
reader
,
const
std
::
vector
<
std
::
string
>&
fetch_var_names
,
Scope
*
root_scope
,
const
int
thread_index
,
const
bool
debug
);
public:
public:
std
::
shared_ptr
<
paddle
::
framework
::
FleetWrapper
>
fleet_ptr_
;
std
::
shared_ptr
<
paddle
::
framework
::
FleetWrapper
>
fleet_ptr_
;
Scope
*
root_scope_
;
Scope
*
root_scope_
;
...
...
paddle/fluid/framework/device_worker.h
浏览文件 @
c1650120
...
@@ -39,12 +39,11 @@ namespace framework {
...
@@ -39,12 +39,11 @@ namespace framework {
class
PullDenseWorker
{
class
PullDenseWorker
{
public:
public:
PullDenseWorker
()
{}
virtual
~
PullDenseWorker
()
{}
virtual
~
PullDenseWorker
()
{}
virtual
void
Initialize
(
const
TrainerDesc
&
param
);
virtual
void
Initialize
(
const
TrainerDesc
&
param
);
int
Start
();
int
Start
();
void
Stop
();
void
Stop
();
void
SetScope
(
Scope
*
scope
)
{
root_scope_
=
scope
;
}
void
Set
Root
Scope
(
Scope
*
scope
)
{
root_scope_
=
scope
;
}
void
IncreaseThreadVersion
(
int
thread_id
,
uint64_t
table_id
);
void
IncreaseThreadVersion
(
int
thread_id
,
uint64_t
table_id
);
void
ResetThreadVersion
(
uint64_t
table_id
);
void
ResetThreadVersion
(
uint64_t
table_id
);
void
Wait
(
std
::
vector
<::
std
::
future
<
int32_t
>>*
status_vec
);
void
Wait
(
std
::
vector
<::
std
::
future
<
int32_t
>>*
status_vec
);
...
@@ -57,6 +56,7 @@ class PullDenseWorker {
...
@@ -57,6 +56,7 @@ class PullDenseWorker {
}
}
private:
private:
PullDenseWorker
()
:
root_scope_
(
NULL
)
{}
void
Run
();
void
Run
();
bool
CheckUpdateParam
(
uint64_t
table_id
);
bool
CheckUpdateParam
(
uint64_t
table_id
);
...
@@ -137,20 +137,18 @@ class HogwildWorker : public CPUWorkerBase {
...
@@ -137,20 +137,18 @@ class HogwildWorker : public CPUWorkerBase {
protected:
protected:
void
CreateThreadOperators
(
const
ProgramDesc
&
program
);
void
CreateThreadOperators
(
const
ProgramDesc
&
program
);
void
CreateThreadScope
(
const
ProgramDesc
&
program
);
void
CreateThreadScope
(
const
ProgramDesc
&
program
);
std
::
shared_ptr
<
DataFeed
>
thread_reader_
;
std
::
vector
<
std
::
string
>
op_names_
;
std
::
vector
<
std
::
string
>
op_names_
;
std
::
vector
<
OperatorBase
*>
ops_
;
std
::
vector
<
OperatorBase
*>
ops_
;
Scope
*
thread_scope_
;
Scope
*
thread_scope_
;
std
::
vector
<
std
::
string
>
fetch_var_names_
;
std
::
vector
<
std
::
string
>
fetch_var_names_
;
std
::
vector
<
std
::
vector
<
float
>>
fetch_values_
;
std
::
vector
<
std
::
vector
<
float
>>
fetch_values_
;
platform
::
Place
place_
;
};
};
class
DownpourWorker
:
public
HogwildWorker
{
class
DownpourWorker
:
public
HogwildWorker
{
public:
public:
DownpourWorker
()
{}
DownpourWorker
()
{}
virtual
~
DownpourWorker
()
{}
virtual
~
DownpourWorker
()
{}
virtual
void
Initilize
(
const
TrainerDesc
&
desc
);
virtual
void
Initi
a
lize
(
const
TrainerDesc
&
desc
);
virtual
void
TrainFiles
();
virtual
void
TrainFiles
();
protected:
protected:
...
@@ -163,7 +161,7 @@ class DownpourWorker : public HogwildWorker {
...
@@ -163,7 +161,7 @@ class DownpourWorker : public HogwildWorker {
private:
private:
DownpourWorkerParameter
param_
;
DownpourWorkerParameter
param_
;
// just save the value in param_ for easy access
// just save the value in param_ for easy access
std
::
string
label_var_name_
;
std
::
map
<
uint64_t
,
std
::
string
>
label_var_name_
;
std
::
map
<
uint64_t
,
std
::
vector
<
std
::
string
>>
sparse_key_names_
;
std
::
map
<
uint64_t
,
std
::
vector
<
std
::
string
>>
sparse_key_names_
;
std
::
map
<
uint64_t
,
std
::
vector
<
std
::
string
>>
sparse_value_names_
;
std
::
map
<
uint64_t
,
std
::
vector
<
std
::
string
>>
sparse_value_names_
;
std
::
map
<
uint64_t
,
std
::
vector
<
std
::
string
>>
sparse_grad_names_
;
std
::
map
<
uint64_t
,
std
::
vector
<
std
::
string
>>
sparse_grad_names_
;
...
...
paddle/fluid/framework/device_worker_factory.cc
浏览文件 @
c1650120
...
@@ -19,25 +19,6 @@ limitations under the License. */
...
@@ -19,25 +19,6 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
typedef
std
::
shared_ptr
<
DeviceWorker
>
(
*
Createdevice_workerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
Createdevice_workerFunction
>
device_workerMap
;
device_workerMap
g_device_worker_map
;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
std
::
string
DeviceWorkerFactory
::
DeviceWorkerTypeList
()
{
std
::
string
DeviceWorkerFactory
::
DeviceWorkerTypeList
()
{
std
::
string
device_worker_types
;
std
::
string
device_worker_types
;
...
@@ -59,7 +40,5 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
...
@@ -59,7 +40,5 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
return
g_device_worker_map
[
device_worker_class
]();
return
g_device_worker_map
[
device_worker_class
]();
}
}
REGISTER_DEVICE_WORKER_CLASS
(
HogwildWorker
);
REGISTER_DEVICE_WORKER_CLASS
(
DownpourWorker
);
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/device_worker_factory.h
0 → 100644
浏览文件 @
c1650120
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/device_worker.h"
namespace
paddle
{
namespace
framework
{
typedef
std
::
shared_ptr
<
DeviceWorker
>
(
*
Createdevice_workerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
Createdevice_workerFunction
>
device_workerMap
;
device_workerMap
g_device_worker_map
;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
class
DeviceWorkerFactory
{
public:
static
std
::
string
DeviceWorkerTypeList
();
static
std
::
shared_ptr
<
DeviceWorker
>
CreateDeviceWorker
(
std
::
string
device_worker_class
);
};
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/dist_multi_trainer.cc
浏览文件 @
c1650120
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -34,6 +35,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
...
@@ -34,6 +35,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
workers_
[
i
]
->
SetDeviceIndex
(
i
);
workers_
[
i
]
->
SetDeviceIndex
(
i
);
readers_
[
i
]
->
Init
(
trainer_desc
.
data_desc
());
readers_
[
i
]
->
Init
(
trainer_desc
.
data_desc
());
workers_
[
i
]
->
SetDataFeed
(
readers_
[
i
]);
workers_
[
i
]
->
SetDataFeed
(
readers_
[
i
]);
workers_
[
i
]
->
Initialize
(
trainer_desc
);
}
}
std
::
vector
<
std
::
string
>
filelist_vec
;
std
::
vector
<
std
::
string
>
filelist_vec
;
...
@@ -41,13 +43,15 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
...
@@ -41,13 +43,15 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
filelist_vec
.
push_back
(
trainer_desc
.
filelist
(
i
));
filelist_vec
.
push_back
(
trainer_desc
.
filelist
(
i
));
}
}
readers_
[
0
]
->
SetFileList
(
filelist_vec
);
fleet_ptr_
=
FleetWrapper
::
GetInstance
();
fleet_ptr_
=
FleetWrapper
::
GetInstance
();
pull_dense_worker_
=
PullDenseWorker
::
GetInstance
();
pull_dense_worker_
=
PullDenseWorker
::
GetInstance
();
pull_dense_worker_
->
Initialize
(
trainer_desc
);
pull_dense_worker_
->
Initialize
(
trainer_desc
);
}
}
void
DistMultiTrainer
::
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
{
void
DistMultiTrainer
::
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
{
pull_dense_worker_
->
SetScope
(
root_scope_
);
pull_dense_worker_
->
Set
Root
Scope
(
root_scope_
);
pull_dense_worker_
->
Start
();
pull_dense_worker_
->
Start
();
}
}
...
@@ -58,5 +62,6 @@ void DistMultiTrainer::Finalize() {
...
@@ -58,5 +62,6 @@ void DistMultiTrainer::Finalize() {
pull_dense_worker_
->
Stop
();
pull_dense_worker_
->
Stop
();
}
}
REGISTER_TRAINER_CLASS
(
DistMultiTrainer
);
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
paddle/fluid/framework/downpour_worker.cc
浏览文件 @
c1650120
/* Copyright (c) 201
8
PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 201
9
PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
...
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
...
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
void
DownpourWorker
::
Initilize
(
const
TrainerDesc
&
desc
)
{
void
DownpourWorker
::
Initi
a
lize
(
const
TrainerDesc
&
desc
)
{
param_
=
desc
.
downpour_param
();
param_
=
desc
.
downpour_param
();
for
(
size_t
i
=
0
;
i
<
param_
.
sparse_table_size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
param_
.
sparse_table_size
();
++
i
)
{
uint64_t
table_id
=
uint64_t
table_id
=
static_cast
<
uint64_t
>
(
param_
.
sparse_table
(
i
).
table_id
());
static_cast
<
uint64_t
>
(
param_
.
sparse_table
(
i
).
table_id
());
...
@@ -37,6 +37,7 @@ void DownpourWorker::Initilize(const TrainerDesc& desc) {
...
@@ -37,6 +37,7 @@ void DownpourWorker::Initilize(const TrainerDesc& desc) {
for
(
size_t
j
=
0
;
j
<
table
.
sparse_grad_name_size
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
table
.
sparse_grad_name_size
();
++
j
)
{
sparse_grad_names_
[
table_id
][
j
]
=
table
.
sparse_grad_name
(
j
);
sparse_grad_names_
[
table_id
][
j
]
=
table
.
sparse_grad_name
(
j
);
}
}
label_var_name_
[
table_id
]
=
table
.
label_var_name
();
}
}
for
(
size_t
i
=
0
;
i
<
param_
.
dense_table_size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
param_
.
dense_table_size
();
++
i
)
{
...
@@ -56,15 +57,18 @@ void DownpourWorker::Initilize(const TrainerDesc& desc) {
...
@@ -56,15 +57,18 @@ void DownpourWorker::Initilize(const TrainerDesc& desc) {
for
(
size_t
i
=
0
;
i
<
param_
.
skip_ops_size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
param_
.
skip_ops_size
();
++
i
)
{
skip_ops_
[
i
]
=
param_
.
skip_ops
(
i
);
skip_ops_
[
i
]
=
param_
.
skip_ops
(
i
);
}
}
skip_ops_
.
resize
(
param_
.
skip_ops_size
());
label_var_name_
=
param_
.
label_var_name
();
}
}
void
DownpourWorker
::
CollectLabelInfo
(
size_t
table_id
)
{
void
DownpourWorker
::
CollectLabelInfo
(
size_t
table_idx
)
{
auto
table
=
param_
.
sparse_table
(
table_idx
);
uint64_t
table_id
=
static_cast
<
uint64_t
>
(
param_
.
sparse_table
(
table_idx
).
table_id
());
auto
&
feature
=
features_
[
table_id
];
auto
&
feature
=
features_
[
table_id
];
auto
&
feature_label
=
feature_labels_
[
table_id
];
auto
&
feature_label
=
feature_labels_
[
table_id
];
feature_label
.
resize
(
feature
.
size
());
feature_label
.
resize
(
feature
.
size
());
Variable
*
var
=
thread_scope_
->
FindVar
(
label_var_name_
);
Variable
*
var
=
thread_scope_
->
FindVar
(
label_var_name_
[
table_id
]
);
LoDTensor
*
tensor
=
var
->
GetMutable
<
LoDTensor
>
();
LoDTensor
*
tensor
=
var
->
GetMutable
<
LoDTensor
>
();
int64_t
*
label_ptr
=
tensor
->
data
<
int64_t
>
();
int64_t
*
label_ptr
=
tensor
->
data
<
int64_t
>
();
...
@@ -75,13 +79,14 @@ void DownpourWorker::CollectLabelInfo(size_t table_id) {
...
@@ -75,13 +79,14 @@ void DownpourWorker::CollectLabelInfo(size_t table_id) {
int64_t
*
ids
=
tensor
->
data
<
int64_t
>
();
int64_t
*
ids
=
tensor
->
data
<
int64_t
>
();
int
fea_idx
=
0
;
int
fea_idx
=
0
;
// tensor->lod()[0].size() == batch_size + 1
// tensor->lod()[0].size() == batch_size + 1
for
(
auto
ins_idx
=
0u
;
ins_idx
<
tensor
->
lod
()[
0
].
size
()
-
1
;
++
ins
_idx
)
{
for
(
auto
lod_idx
=
1u
;
lod_idx
<
tensor
->
lod
()[
0
].
size
();
++
lod
_idx
)
{
for
(;
fea_idx
<
tensor
->
lod
()[
0
][
ins
_idx
];
++
fea_idx
)
{
for
(;
fea_idx
<
tensor
->
lod
()[
0
][
lod
_idx
];
++
fea_idx
)
{
// should be skipped feasign defined in protobuf
// should be skipped feasign defined in protobuf
if
(
ids
[
fea_idx
]
==
0u
)
{
if
(
ids
[
fea_idx
]
==
0u
)
{
continue
;
continue
;
}
}
feature_label
[
global_index
++
]
=
static_cast
<
float
>
(
label_ptr
[
ins_idx
]);
feature_label
[
global_index
++
]
=
static_cast
<
float
>
(
label_ptr
[
lod_idx
-
1
]);
}
}
}
}
}
}
...
@@ -128,10 +133,10 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
...
@@ -128,10 +133,10 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
void
DownpourWorker
::
TrainFiles
()
{
void
DownpourWorker
::
TrainFiles
()
{
platform
::
SetNumThreads
(
1
);
platform
::
SetNumThreads
(
1
);
thread
_reader_
->
Start
();
device
_reader_
->
Start
();
int
batch_cnt
=
0
;
int
batch_cnt
=
0
;
int
cur_batch
;
int
cur_batch
;
while
((
cur_batch
=
thread
_reader_
->
Next
())
>
0
)
{
while
((
cur_batch
=
device
_reader_
->
Next
())
>
0
)
{
// pull sparse here
// pull sparse here
for
(
size_t
i
=
0
;
i
<
param_
.
sparse_table_size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
param_
.
sparse_table_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
sparse_table
(
i
).
table_id
());
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
sparse_table
(
i
).
table_id
());
...
@@ -144,7 +149,16 @@ void DownpourWorker::TrainFiles() {
...
@@ -144,7 +149,16 @@ void DownpourWorker::TrainFiles() {
// do computation here
// do computation here
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
*
thread_scope_
,
place_
);
bool
need_skip
=
false
;
for
(
auto
t
=
0u
;
t
<
skip_ops_
.
size
();
++
t
)
{
if
(
op
->
Type
().
find
(
skip_ops_
[
t
])
!=
std
::
string
::
npos
)
{
need_skip
=
true
;
break
;
}
}
if
(
!
need_skip
)
{
op
->
Run
(
*
thread_scope_
,
place_
);
}
}
}
// push gradients here
// push gradients here
...
@@ -198,10 +212,12 @@ void DownpourWorker::TrainFiles() {
...
@@ -198,10 +212,12 @@ void DownpourWorker::TrainFiles() {
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
dense_table
(
i
).
table_id
());
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
dense_table
(
i
).
table_id
());
pull_dense_worker_
->
IncreaseThreadVersion
(
thread_id_
,
tid
);
pull_dense_worker_
->
IncreaseThreadVersion
(
thread_id_
,
tid
);
}
}
thread_scope_
->
DropKids
();
thread_scope_
->
DropKids
();
++
batch_cnt
;
++
batch_cnt
;
}
}
}
}
REGISTER_DEVICE_WORKER_CLASS
(
DownpourWorker
);
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
paddle/fluid/framework/fleet/CMakeLists.txt
浏览文件 @
c1650120
cc_library
(
fleet_wrapper SRCS fleet_wrapper.cc
)
cc_library
(
fleet_wrapper SRCS fleet_wrapper.cc
DEPS pslib_brpc pslib
)
paddle/fluid/framework/fleet/fleet_wrapper.cc
浏览文件 @
c1650120
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -19,10 +33,16 @@ namespace framework {
...
@@ -19,10 +33,16 @@ namespace framework {
const
uint32_t
MAX_FEASIGN_NUM
=
1024
*
100
*
100
;
const
uint32_t
MAX_FEASIGN_NUM
=
1024
*
100
*
100
;
std
::
shared_ptr
<
FleetWrapper
>
FleetWrapper
::
s_instance_
=
NULL
;
std
::
shared_ptr
<
FleetWrapper
>
FleetWrapper
::
s_instance_
=
NULL
;
bool
FleetWrapper
::
is_initialized_
=
false
;
#ifdef PADDLE_WITH_PSLIB
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
FleetWrapper
::
pslib_ptr_
=
NULL
;
#endif
void
FleetWrapper
::
InitServer
(
const
std
::
string
&
dist_desc
,
int
index
)
{
void
FleetWrapper
::
InitServer
(
const
std
::
string
&
dist_desc
,
int
index
)
{
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
if
(
!
is_initialized_
)
{
if
(
!
is_initialized_
)
{
LOG
(
WARNING
)
<<
"Going to init server"
;
pslib_ptr_
=
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
(
pslib_ptr_
=
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
(
new
paddle
::
distributed
::
PSlib
());
new
paddle
::
distributed
::
PSlib
());
pslib_ptr_
->
init_server
(
dist_desc
,
index
);
pslib_ptr_
->
init_server
(
dist_desc
,
index
);
...
@@ -38,6 +58,7 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
...
@@ -38,6 +58,7 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
int
node_num
,
int
index
)
{
int
node_num
,
int
index
)
{
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
if
(
!
is_initialized_
)
{
if
(
!
is_initialized_
)
{
LOG
(
WARNING
)
<<
"Going to init server"
;
pslib_ptr_
=
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
(
pslib_ptr_
=
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
(
new
paddle
::
distributed
::
PSlib
());
new
paddle
::
distributed
::
PSlib
());
pslib_ptr_
->
init_worker
(
dist_desc
,
pslib_ptr_
->
init_worker
(
dist_desc
,
...
@@ -52,12 +73,14 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
...
@@ -52,12 +73,14 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
void
FleetWrapper
::
StopServer
()
{
void
FleetWrapper
::
StopServer
()
{
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
LOG
(
WARNING
)
<<
"Going to stop server"
;
pslib_ptr_
->
stop_server
();
pslib_ptr_
->
stop_server
();
#endif
#endif
}
}
uint64_t
FleetWrapper
::
RunServer
()
{
uint64_t
FleetWrapper
::
RunServer
()
{
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
LOG
(
WARNING
)
<<
"Going to run server"
;
return
pslib_ptr_
->
run_server
();
return
pslib_ptr_
->
run_server
();
#else
#else
return
0
;
return
0
;
...
@@ -67,6 +90,7 @@ uint64_t FleetWrapper::RunServer() {
...
@@ -67,6 +90,7 @@ uint64_t FleetWrapper::RunServer() {
void
FleetWrapper
::
GatherServers
(
const
std
::
vector
<
uint64_t
>&
host_sign_list
,
void
FleetWrapper
::
GatherServers
(
const
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
)
{
int
node_num
)
{
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
LOG
(
WARNING
)
<<
"Going to gather server ips"
;
pslib_ptr_
->
gather_servers
(
const_cast
<
uint64_t
*>
(
host_sign_list
.
data
()),
pslib_ptr_
->
gather_servers
(
const_cast
<
uint64_t
*>
(
host_sign_list
.
data
()),
node_num
);
node_num
);
#endif
#endif
...
@@ -122,13 +146,13 @@ void FleetWrapper::PullDenseVarsAsync(
...
@@ -122,13 +146,13 @@ void FleetWrapper::PullDenseVarsAsync(
std
::
vector
<::
std
::
future
<
int32_t
>>*
pull_dense_status
)
{
std
::
vector
<::
std
::
future
<
int32_t
>>*
pull_dense_status
)
{
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
std
::
vector
<
paddle
::
ps
::
Region
>
regions
;
std
::
vector
<
paddle
::
ps
::
Region
>
regions
;
regions
.
res
erv
e
(
var_names
.
size
());
regions
.
res
iz
e
(
var_names
.
size
());
for
(
auto
&
t
:
var_names
)
{
for
(
auto
i
=
0u
;
i
<
var_names
.
size
();
++
i
)
{
Variable
*
var
=
scope
.
FindVar
(
t
);
Variable
*
var
=
scope
.
FindVar
(
var_names
[
i
]
);
LoDTensor
*
tensor
=
var
->
GetMutable
<
LoDTensor
>
();
LoDTensor
*
tensor
=
var
->
GetMutable
<
LoDTensor
>
();
float
*
w
=
tensor
->
data
<
float
>
();
float
*
w
=
tensor
->
data
<
float
>
();
paddle
::
ps
::
Region
reg
(
w
,
tensor
->
numel
());
paddle
::
ps
::
Region
reg
(
w
,
tensor
->
numel
());
regions
.
emplace_back
(
std
::
move
(
reg
)
);
regions
[
i
]
=
std
::
move
(
reg
);
}
}
auto
status
=
auto
status
=
pslib_ptr_
->
_worker_ptr
->
pull_dense
(
regions
.
data
(),
regions
.
size
(),
tid
);
pslib_ptr_
->
_worker_ptr
->
pull_dense
(
regions
.
data
(),
regions
.
size
(),
tid
);
...
@@ -186,7 +210,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
...
@@ -186,7 +210,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
int
offset
=
2
;
int
offset
=
2
;
uint64_t
fea_idx
=
0u
;
uint64_t
fea_idx
=
0u
;
for
(
size_t
i
=
0
;
i
<
sparse_key_names
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
sparse_key_names
.
size
();
++
i
)
{
Variable
*
g_var
=
scope
.
FindVar
(
sparse_key_names
[
i
]);
LOG
(
WARNING
)
<<
"sparse key names["
<<
i
<<
"]: "
<<
sparse_key_names
[
i
];
LOG
(
WARNING
)
<<
"sparse grad names["
<<
i
<<
"]: "
<<
sparse_grad_names
[
i
];
Variable
*
g_var
=
scope
.
FindVar
(
sparse_grad_names
[
i
]);
CHECK
(
g_var
!=
nullptr
)
<<
"var["
<<
sparse_grad_names
[
i
]
<<
"] not found"
;
LoDTensor
*
g_tensor
=
g_var
->
GetMutable
<
LoDTensor
>
();
LoDTensor
*
g_tensor
=
g_var
->
GetMutable
<
LoDTensor
>
();
if
(
g_tensor
==
NULL
)
{
if
(
g_tensor
==
NULL
)
{
LOG
(
ERROR
)
<<
"var["
<<
sparse_key_names
[
i
]
<<
"] not found"
;
LOG
(
ERROR
)
<<
"var["
<<
sparse_key_names
[
i
]
<<
"] not found"
;
...
@@ -201,16 +228,26 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
...
@@ -201,16 +228,26 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
exit
(
-
1
);
exit
(
-
1
);
}
}
int
len
=
tensor
->
numel
();
int
len
=
tensor
->
numel
();
LOG
(
WARNING
)
<<
" tensor len: "
<<
len
;
int64_t
*
ids
=
tensor
->
data
<
int64_t
>
();
int64_t
*
ids
=
tensor
->
data
<
int64_t
>
();
push_values
->
resize
(
fea_keys
.
size
()
+
1
);
for
(
auto
&
t
:
*
push_values
)
{
t
.
resize
(
emb_dim
+
offset
);
}
for
(
auto
id_idx
=
0u
;
id_idx
<
len
;
++
id_idx
)
{
for
(
auto
id_idx
=
0u
;
id_idx
<
len
;
++
id_idx
)
{
if
(
ids
[
id_idx
]
==
0
)
{
if
(
ids
[
id_idx
]
==
0
)
{
g
+=
emb_dim
;
g
+=
emb_dim
;
continue
;
continue
;
}
}
LOG
(
WARNING
)
<<
"going to memcpy"
;
memcpy
((
*
push_values
)[
fea_idx
].
data
()
+
offset
,
g
,
memcpy
((
*
push_values
)[
fea_idx
].
data
()
+
offset
,
g
,
sizeof
(
float
)
*
emb_dim
);
sizeof
(
float
)
*
emb_dim
);
LOG
(
WARNING
)
<<
"show"
;
(
*
push_values
)[
fea_idx
][
0
]
=
1.0
f
;
(
*
push_values
)[
fea_idx
][
0
]
=
1.0
f
;
LOG
(
WARNING
)
<<
"click"
;
(
*
push_values
)[
fea_idx
][
1
]
=
static_cast
<
float
>
(
fea_labels
[
fea_idx
]);
(
*
push_values
)[
fea_idx
][
1
]
=
static_cast
<
float
>
(
fea_labels
[
fea_idx
]);
LOG
(
WARNING
)
<<
"offset"
;
g
+=
emb_dim
;
g
+=
emb_dim
;
fea_idx
++
;
fea_idx
++
;
}
}
...
...
paddle/fluid/framework/fleet/fleet_wrapper.h
浏览文件 @
c1650120
...
@@ -47,7 +47,6 @@ namespace framework {
...
@@ -47,7 +47,6 @@ namespace framework {
class
FleetWrapper
{
class
FleetWrapper
{
public:
public:
FleetWrapper
()
{}
virtual
~
FleetWrapper
()
{}
virtual
~
FleetWrapper
()
{}
// Pull sparse variables from server in Sync mode
// Pull sparse variables from server in Sync mode
...
@@ -122,8 +121,11 @@ class FleetWrapper {
...
@@ -122,8 +121,11 @@ class FleetWrapper {
static
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
pslib_ptr_
;
static
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
pslib_ptr_
;
#endif
#endif
private:
FleetWrapper
()
{}
protected:
protected:
bool
is_initialized_
;
static
bool
is_initialized_
;
DISABLE_COPY_AND_ASSIGN
(
FleetWrapper
);
DISABLE_COPY_AND_ASSIGN
(
FleetWrapper
);
};
};
...
...
paddle/fluid/framework/hogwild_worker.cc
浏览文件 @
c1650120
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -50,9 +51,9 @@ void HogwildWorker::CreateThreadScope(const ProgramDesc& program) {
...
@@ -50,9 +51,9 @@ void HogwildWorker::CreateThreadScope(const ProgramDesc& program) {
void
HogwildWorker
::
BindingDataFeedMemory
()
{
void
HogwildWorker
::
BindingDataFeedMemory
()
{
const
std
::
vector
<
std
::
string
>&
input_feed
=
const
std
::
vector
<
std
::
string
>&
input_feed
=
thread
_reader_
->
GetUseSlotAlias
();
device
_reader_
->
GetUseSlotAlias
();
for
(
auto
name
:
input_feed
)
{
for
(
auto
name
:
input_feed
)
{
thread
_reader_
->
AddFeedVar
(
thread_scope_
->
Var
(
name
),
name
);
device
_reader_
->
AddFeedVar
(
thread_scope_
->
Var
(
name
),
name
);
}
}
}
}
...
@@ -63,7 +64,7 @@ void HogwildWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
...
@@ -63,7 +64,7 @@ void HogwildWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
void
HogwildWorker
::
TrainFilesWithProfiler
()
{
void
HogwildWorker
::
TrainFilesWithProfiler
()
{
platform
::
SetNumThreads
(
1
);
platform
::
SetNumThreads
(
1
);
thread
_reader_
->
Start
();
device
_reader_
->
Start
();
std
::
vector
<
double
>
op_total_time
;
std
::
vector
<
double
>
op_total_time
;
std
::
vector
<
std
::
string
>
op_name
;
std
::
vector
<
std
::
string
>
op_name
;
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
...
@@ -79,7 +80,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
...
@@ -79,7 +80,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
int
cur_batch
;
int
cur_batch
;
int
batch_cnt
=
0
;
int
batch_cnt
=
0
;
timeline
.
Start
();
timeline
.
Start
();
while
((
cur_batch
=
thread
_reader_
->
Next
())
>
0
)
{
while
((
cur_batch
=
device
_reader_
->
Next
())
>
0
)
{
timeline
.
Pause
();
timeline
.
Pause
();
read_time
+=
timeline
.
ElapsedSec
();
read_time
+=
timeline
.
ElapsedSec
();
total_time
+=
timeline
.
ElapsedSec
();
total_time
+=
timeline
.
ElapsedSec
();
...
@@ -115,10 +116,10 @@ void HogwildWorker::TrainFiles() {
...
@@ -115,10 +116,10 @@ void HogwildWorker::TrainFiles() {
platform
::
SetNumThreads
(
1
);
platform
::
SetNumThreads
(
1
);
// how to accumulate fetched values here
// how to accumulate fetched values here
thread
_reader_
->
Start
();
device
_reader_
->
Start
();
int
cur_batch
;
int
cur_batch
;
int
batch_cnt
=
0
;
int
batch_cnt
=
0
;
while
((
cur_batch
=
thread
_reader_
->
Next
())
>
0
)
{
while
((
cur_batch
=
device
_reader_
->
Next
())
>
0
)
{
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
*
thread_scope_
,
place_
);
op
->
Run
(
*
thread_scope_
,
place_
);
}
}
...
@@ -128,5 +129,6 @@ void HogwildWorker::TrainFiles() {
...
@@ -128,5 +129,6 @@ void HogwildWorker::TrainFiles() {
}
}
}
}
REGISTER_DEVICE_WORKER_CLASS
(
HogwildWorker
);
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
c1650120
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -65,5 +66,6 @@ void MultiTrainer::Finalize() {
...
@@ -65,5 +66,6 @@ void MultiTrainer::Finalize() {
}
}
}
}
REGISTER_TRAINER_CLASS
(
MultiTrainer
);
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
paddle/fluid/framework/pull_dense_worker.cc
浏览文件 @
c1650120
...
@@ -20,24 +20,31 @@ namespace framework {
...
@@ -20,24 +20,31 @@ namespace framework {
std
::
shared_ptr
<
PullDenseWorker
>
PullDenseWorker
::
s_instance_
=
NULL
;
std
::
shared_ptr
<
PullDenseWorker
>
PullDenseWorker
::
s_instance_
=
NULL
;
void
PullDenseWorker
::
Initialize
(
const
TrainerDesc
&
param
)
{
void
PullDenseWorker
::
Initialize
(
const
TrainerDesc
&
param
)
{
LOG
(
WARNING
)
<<
"going to initialize pull dense worker"
;
running_
=
false
;
running_
=
false
;
param_
=
param
.
pull_dense_param
();
param_
=
param
.
pull_dense_param
();
threshold_
=
param_
.
threshold
();
threshold_
=
param_
.
threshold
();
thread_num_
=
param_
.
device_num
();
thread_num_
=
param_
.
device_num
();
sleep_time_ms_
=
param_
.
sleep_time_ms
();
sleep_time_ms_
=
param_
.
sleep_time_ms
();
LOG
(
WARNING
)
<<
"dense table size: "
<<
param_
.
dense_table_size
();
LOG
(
WARNING
)
<<
"thread num: "
<<
thread_num_
;
for
(
size_t
i
=
0
;
i
<
param_
.
dense_table_size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
param_
.
dense_table_size
();
++
i
)
{
// setup dense variables for each table
// setup dense variables for each table
int
var_num
=
param_
.
dense_table
(
i
).
dense_value_name_size
();
int
var_num
=
param_
.
dense_table
(
i
).
dense_value_name_size
();
LOG
(
WARNING
)
<<
"var num: "
<<
var_num
;
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
dense_table
(
i
).
table_id
());
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
dense_table
(
i
).
table_id
());
dense_value_names_
[
tid
].
resize
(
var_num
);
dense_value_names_
[
tid
].
resize
(
var_num
);
for
(
int
j
=
0
;
j
<
var_num
;
++
j
)
{
for
(
int
j
=
0
;
j
<
var_num
;
++
j
)
{
dense_value_names_
[
tid
][
j
]
=
param_
.
dense_table
(
i
).
dense_value_name
(
j
);
dense_value_names_
[
tid
][
j
]
=
param_
.
dense_table
(
i
).
dense_value_name
(
j
);
LOG
(
WARNING
)
<<
"dense value names "
<<
j
<<
" "
<<
dense_value_names_
[
tid
][
j
];
}
}
// setup training version for each table
// setup training version for each table
training_versions_
[
tid
].
resize
(
thread_num_
,
0
);
training_versions_
[
tid
].
resize
(
thread_num_
,
0
);
last_versions_
[
tid
]
=
0
;
last_versions_
[
tid
]
=
0
;
current_version_
[
tid
]
=
0
;
current_version_
[
tid
]
=
0
;
}
}
LOG
(
WARNING
)
<<
"initialize pull dense worker done."
;
}
}
void
PullDenseWorker
::
Wait
(
std
::
vector
<::
std
::
future
<
int32_t
>>*
status_vec
)
{
void
PullDenseWorker
::
Wait
(
std
::
vector
<::
std
::
future
<
int32_t
>>*
status_vec
)
{
...
@@ -56,6 +63,7 @@ void PullDenseWorker::Wait(std::vector<::std::future<int32_t>>* status_vec) {
...
@@ -56,6 +63,7 @@ void PullDenseWorker::Wait(std::vector<::std::future<int32_t>>* status_vec) {
<<
" Times"
;
<<
" Times"
;
exit
(
-
1
);
exit
(
-
1
);
}
}
status_vec
->
resize
(
0
);
}
}
void
PullDenseWorker
::
Stop
()
{
void
PullDenseWorker
::
Stop
()
{
...
@@ -90,7 +98,10 @@ void PullDenseWorker::Run() {
...
@@ -90,7 +98,10 @@ void PullDenseWorker::Run() {
}
}
void
PullDenseWorker
::
IncreaseThreadVersion
(
int
thread_id
,
uint64_t
table_id
)
{
void
PullDenseWorker
::
IncreaseThreadVersion
(
int
thread_id
,
uint64_t
table_id
)
{
LOG
(
WARNING
)
<<
"increase thread version input: "
<<
thread_id
<<
" table id "
<<
table_id
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_for_version_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_for_version_
);
LOG
(
WARNING
)
<<
"going to increase"
;
training_versions_
[
table_id
][
thread_id
]
++
;
training_versions_
[
table_id
][
thread_id
]
++
;
}
}
...
...
paddle/fluid/framework/trainer.cc
浏览文件 @
c1650120
...
@@ -19,7 +19,5 @@ namespace framework {
...
@@ -19,7 +19,5 @@ namespace framework {
void
TrainerBase
::
SetScope
(
Scope
*
root_scope
)
{
root_scope_
=
root_scope
;
}
void
TrainerBase
::
SetScope
(
Scope
*
root_scope
)
{
root_scope_
=
root_scope
;
}
void
TrainerBase
::
Initialize
(
const
TrainerDesc
&
trainer_desc
)
{
return
;
}
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
paddle/fluid/framework/trainer.h
浏览文件 @
c1650120
...
@@ -39,8 +39,8 @@ class TrainerBase {
...
@@ -39,8 +39,8 @@ class TrainerBase {
virtual
~
TrainerBase
()
{}
virtual
~
TrainerBase
()
{}
// model memory are hosted in root_scope
// model memory are hosted in root_scope
void
SetScope
(
Scope
*
root_scope
);
void
SetScope
(
Scope
*
root_scope
);
void
Initialize
(
const
TrainerDesc
&
trainer_desc
);
void
SetDebug
(
const
bool
debug
)
{
debug_
=
debug
;
}
void
SetDebug
(
const
bool
debug
)
{
debug_
=
debug
;
}
virtual
void
Initialize
(
const
TrainerDesc
&
trainer_desc
)
=
0
;
virtual
void
InitTrainerEnv
(
const
ProgramDesc
&
main_program
,
virtual
void
InitTrainerEnv
(
const
ProgramDesc
&
main_program
,
const
platform
::
Place
&
place
)
=
0
;
const
platform
::
Place
&
place
)
=
0
;
virtual
void
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
=
0
;
virtual
void
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
=
0
;
...
...
paddle/fluid/framework/trainer_desc.proto
浏览文件 @
c1650120
...
@@ -43,7 +43,6 @@ message DownpourWorkerParameter {
...
@@ -43,7 +43,6 @@ message DownpourWorkerParameter {
repeated
TableParameter
sparse_table
=
1
;
repeated
TableParameter
sparse_table
=
1
;
repeated
TableParameter
dense_table
=
2
;
repeated
TableParameter
dense_table
=
2
;
repeated
string
skip_ops
=
3
;
repeated
string
skip_ops
=
3
;
optional
string
label_var_name
=
4
;
}
}
message
PullDenseWorkerParameter
{
message
PullDenseWorkerParameter
{
...
...
paddle/fluid/framework/trainer_factory.cc
浏览文件 @
c1650120
...
@@ -21,23 +21,6 @@ limitations under the License. */
...
@@ -21,23 +21,6 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
typedef
std
::
shared_ptr
<
TrainerBase
>
(
*
CreatetrainerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
CreatetrainerFunction
>
trainerMap
;
trainerMap
g_trainer_map
;
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
std
::
string
TrainerFactory
::
TrainerTypeList
()
{
std
::
string
TrainerFactory
::
TrainerTypeList
()
{
std
::
string
trainer_types
;
std
::
string
trainer_types
;
...
@@ -58,7 +41,5 @@ std::shared_ptr<TrainerBase> TrainerFactory::CreateTrainer(
...
@@ -58,7 +41,5 @@ std::shared_ptr<TrainerBase> TrainerFactory::CreateTrainer(
return
g_trainer_map
[
trainer_class
]();
return
g_trainer_map
[
trainer_class
]();
}
}
REGISTER_TRAINER_CLASS
(
MultiTrainer
);
REGISTER_TRAINER_CLASS
(
DistMultiTrainer
);
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/fluid/framework/trainer_factory.h
0 → 100644
浏览文件 @
c1650120
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/trainer.h"
namespace
paddle
{
namespace
framework
{
typedef
std
::
shared_ptr
<
TrainerBase
>
(
*
CreatetrainerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
CreatetrainerFunction
>
trainerMap
;
trainerMap
g_trainer_map
;
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
class
TrainerFactory
{
public:
static
std
::
string
TrainerTypeList
();
static
std
::
shared_ptr
<
TrainerBase
>
CreateTrainer
(
std
::
string
trainer_class
);
};
}
// namespace framework
}
// namespace paddle
python/paddle/fluid/async_executor.py
浏览文件 @
c1650120
...
@@ -110,15 +110,17 @@ class AsyncExecutor(object):
...
@@ -110,15 +110,17 @@ class AsyncExecutor(object):
is_local
=
self
.
instance
==
None
is_local
=
self
.
instance
==
None
trainer
=
None
trainer
=
None
if
is_local
:
if
is_local
:
trainer
=
MultiTrainer
(
data_feed
=
data_feed
,
worker
=
"Hogwild"
)
trainer
=
MultiTrainer
()
else
:
else
:
trainer
=
DistMultiTrainer
(
trainer
=
DistMultiTrainer
()
data_feed
,
worker
=
"Downpour"
,
fleet_desc
=
self
.
dist_desc
)
trainer
.
gen_trainer_desc
(
dataset
=
data_feed
,
fleet_desc
=
self
.
dist_desc
,
worker
=
"downpour"
)
# define a trainer and a device_worker here
trainer
.
set_thread
(
thread_num
)
trainer
.
set_thread
(
thread_num
)
trainer
.
set_filelist
(
filelist
)
trainer
.
set_filelist
(
filelist
)
trainer
.
set_data_feed
(
data_feed
)
trainer
.
set_data_feed
(
data_feed
)
with
open
(
"trainer_desc.proto"
,
"w"
)
as
fout
:
fout
.
write
(
trainer
.
_desc
())
# define a trainer and a device_worker here
self
.
executor
.
run_from_files
(
program_desc
,
trainer
.
_desc
(),
debug
)
self
.
executor
.
run_from_files
(
program_desc
,
trainer
.
_desc
(),
debug
)
'''
'''
...
@@ -284,8 +286,9 @@ class AsyncExecutor(object):
...
@@ -284,8 +286,9 @@ class AsyncExecutor(object):
raise
ValueError
(
raise
ValueError
(
'instance is None, please run config_distributed_nodes init instance'
'instance is None, please run config_distributed_nodes init instance'
)
)
self
.
init_desc
=
init_desc
self
.
dist_desc_str
=
text_format
.
MessageToString
(
dist_desc
)
self
.
executor
.
init_server
(
dist_desc
,
self
.
instance
.
_rankid
)
self
.
dist_desc
=
dist_desc
self
.
executor
.
init_server
(
self
.
dist_desc_str
,
self
.
instance
.
_rankid
)
ip
=
self
.
executor
.
start_server
()
ip
=
self
.
executor
.
start_server
()
self
.
instance
.
set_ip
(
ip
)
self
.
instance
.
set_ip
(
ip
)
self
.
instance
.
barrier_all
()
#wait all server start
self
.
instance
.
barrier_all
()
#wait all server start
...
@@ -306,6 +309,7 @@ class AsyncExecutor(object):
...
@@ -306,6 +309,7 @@ class AsyncExecutor(object):
'instance is None, please run config_distributed_nodes init instance'
'instance is None, please run config_distributed_nodes init instance'
)
)
self
.
dist_desc_str
=
text_format
.
MessageToString
(
dist_desc
)
self
.
dist_desc
=
dist_desc
self
.
dist_desc
=
dist_desc
place
=
core
.
CPUPlace
()
place
=
core
.
CPUPlace
()
executor
=
Executor
(
place
)
executor
=
Executor
(
place
)
...
@@ -313,7 +317,7 @@ class AsyncExecutor(object):
...
@@ -313,7 +317,7 @@ class AsyncExecutor(object):
self
.
instance
.
barrier_all
()
#wait all server start
self
.
instance
.
barrier_all
()
#wait all server start
ips
=
self
.
instance
.
gather_ips
()
ips
=
self
.
instance
.
gather_ips
()
self
.
executor
.
init_worker
(
dist_desc
,
ips
,
self
.
executor
.
init_worker
(
self
.
dist_desc_str
,
ips
,
self
.
instance
.
get_node_cnt
(),
self
.
instance
.
get_node_cnt
(),
self
.
instance
.
_rankid
)
self
.
instance
.
_rankid
)
self
.
instance
.
barrier_all
()
#wait all worker start
self
.
instance
.
barrier_all
()
#wait all worker start
...
...
python/paddle/fluid/device_worker.py
0 → 100644
浏览文件 @
c1650120
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class
DeviceWorker
(
object
):
def
__init__
(
self
):
pass
def
gen_worker_desc
(
self
,
trainer_desc
,
fleet_desc
):
pass
class
Hogwild
(
DeviceWorker
):
def
__init__
(
self
):
super
(
Hogwild
,
self
).
__init__
()
def
gen_worker_desc
(
self
,
trainer_desc
,
fleet_desc
):
trainer_desc
.
device_worker_name
=
"HogwildWorker"
class
Downpour
(
DeviceWorker
):
def
__init__
(
self
):
super
(
Downpour
,
self
).
__init__
()
def
gen_worker_desc
(
self
,
trainer_desc
,
fleet_desc
):
trainer_desc
.
device_worker_name
=
"DownpourWorker"
pull_thread
=
trainer_desc
.
pull_dense_param
pull_thread
.
device_num
=
trainer_desc
.
thread_num
dense_table
=
pull_thread
.
dense_table
.
add
()
dense_table
.
dense_value_name
.
extend
(
fleet_desc
.
trainer_param
.
dense_table
[
0
].
dense_variable_name
)
dense_table
.
table_id
=
\
fleet_desc
.
trainer_param
.
dense_table
[
0
].
table_id
downpour
=
trainer_desc
.
downpour_param
sparse_table
=
downpour
.
sparse_table
.
add
()
sparse_table
.
table_id
=
\
fleet_desc
.
trainer_param
.
sparse_table
[
0
].
table_id
sparse_table
.
sparse_key_name
.
extend
(
fleet_desc
.
trainer_param
.
sparse_table
[
0
].
slot_key
)
sparse_table
.
sparse_value_name
.
extend
(
fleet_desc
.
trainer_param
.
sparse_table
[
0
].
slot_value
)
sparse_table
.
sparse_grad_name
.
extend
(
fleet_desc
.
trainer_param
.
sparse_table
[
0
].
slot_gradient
)
sparse_table
.
emb_dim
=
fleet_desc
.
server_param
.
downpour_server_param
.
downpour_table_param
[
0
].
accessor
.
fea_dim
-
2
sparse_table
.
fea_dim
=
sparse_table
.
emb_dim
+
2
sparse_table
.
label_var_name
=
"click"
dense_table
=
downpour
.
dense_table
.
add
()
dense_table
.
table_id
=
\
fleet_desc
.
trainer_param
.
dense_table
[
0
].
table_id
dense_table
.
dense_value_name
.
extend
(
fleet_desc
.
trainer_param
.
dense_table
[
0
].
dense_variable_name
)
dense_table
.
dense_grad_name
.
extend
(
fleet_desc
.
trainer_param
.
dense_table
[
0
].
dense_gradient_variable_name
)
downpour
.
skip_ops
.
extend
(
fleet_desc
.
trainer_param
.
skip_op
)
class
DeviceWorkerFactory
(
object
):
def
create_device_worker
(
self
,
worker_type
):
classname
=
worker_type
.
capitalize
()
print
(
"------------"
)
print
(
classname
)
return
globals
()[
classname
]()
python/paddle/fluid/trainer_desc.py
浏览文件 @
c1650120
...
@@ -13,7 +13,8 @@
...
@@ -13,7 +13,8 @@
# limitations under the License.
# limitations under the License.
from
paddle.fluid.proto
import
trainer_desc_pb2
from
paddle.fluid.proto
import
trainer_desc_pb2
import
ps_pb2
as
pslib
from
distributed
import
ps_pb2
as
ps_pb2
from
device_worker
import
DeviceWorkerFactory
from
google.protobuf
import
text_format
from
google.protobuf
import
text_format
__all__
=
[
'TrainerDesc'
,
'MultiTrainer'
,
'DistMultiTrainer'
]
__all__
=
[
'TrainerDesc'
,
'MultiTrainer'
,
'DistMultiTrainer'
]
...
@@ -28,16 +29,22 @@ class TrainerDesc(object):
...
@@ -28,16 +29,22 @@ class TrainerDesc(object):
text_format.Parse(f.read(), self.proto_desc)
text_format.Parse(f.read(), self.proto_desc)
'''
'''
self
.
proto_desc
=
trainer_desc_pb2
.
TrainerDesc
()
self
.
proto_desc
=
trainer_desc_pb2
.
TrainerDesc
()
self
.
proto_desc
.
thread_num
=
12
def
set_thread
(
self
,
thread_num
):
def
set_thread
(
self
,
thread_num
):
self
.
proto_desc
.
thread_num
=
thread_num
self
.
proto_desc
.
thread_num
=
thread_num
def
set_filelist
(
self
,
filelist
):
def
set_filelist
(
self
,
filelist
):
self
.
proto_desc
.
filelist
.
extend
(
filelist
)
self
.
proto_desc
.
filelist
.
extend
(
filelist
)
self
.
proto_desc
.
thread_num
=
min
(
len
(
filelist
),
self
.
proto_desc
.
thread_num
)
def
set_data_feed
(
self
,
datafeed
):
def
set_data_feed
(
self
,
datafeed
):
self
.
proto_desc
.
data_desc
.
CopyFrom
(
datafeed
.
proto_desc
)
self
.
proto_desc
.
data_desc
.
CopyFrom
(
datafeed
.
proto_desc
)
def
gen_trainer_desc
(
self
,
dataset
=
None
,
fleet_desc
=
None
,
worker
=
None
):
pass
def
_desc
(
self
):
def
_desc
(
self
):
return
text_format
.
MessageToString
(
self
.
proto_desc
)
return
text_format
.
MessageToString
(
self
.
proto_desc
)
...
@@ -52,41 +59,20 @@ class MultiTrainer(TrainerDesc):
...
@@ -52,41 +59,20 @@ class MultiTrainer(TrainerDesc):
raise
ValueError
(
'ValueError: DeviceWorker %s '
raise
ValueError
(
'ValueError: DeviceWorker %s '
'is not supported in MultiTrainer'
%
worker
)
'is not supported in MultiTrainer'
%
worker
)
def
gen_trainer_desc
(
self
,
dataset
=
None
,
fleet_desc
=
None
,
worker
=
"Hogwild"
):
super
(
MultiTrainer
,
self
).
gen_trainer_desc
(
fleet_desc
,
worker
)
class
DistMultiTrainer
(
TrainerDesc
):
class
DistMultiTrainer
(
TrainerDesc
):
def
__init__
(
self
,
dataset
=
None
,
worker
=
'Downpour'
,
fleet_desc
=
None
):
def
__init__
(
self
):
super
(
DistMultiTrainer
,
self
).
__init__
()
super
(
DistMultiTrainer
,
self
).
__init__
()
if
worker
==
"Downpour"
:
pass
self
.
proto_desc
.
device_worker_name
=
worker
+
"Worker"
self
.
proto_desc
.
class_name
=
"DistMultiTrainer"
self
.
proto_desc
.
data_feed
.
CopyFrom
(
dataset
)
downpour
=
self
.
proto_desc
.
downpour_param
.
add
()
# sparse table should specify:
sparse_table
=
downpour
.
sparse_table
.
add
()
sparse_table
.
table_id
=
\
fleet_desc
.
trainer_param
.
sparse_table
.
table_id
sparse_table
.
sparse_key_name
.
CopyFrom
(
fleet_desc
.
trainer_param
()
.
sparse_table
().
slot_key
())
sparse_table
.
sparse_value_name
.
CopyFrom
(
fleet_desc
.
trainer_param
(
).
sparse_table
().
slot_value
())
sparse_table
.
sparse_grad_name
.
CopyFrom
(
fleet_desc
.
trainer_param
(
).
sparse_table
().
slot_gradient
())
sparse_table
.
emb_dim
=
fleet_desc
.
server_param
.
downpour_server_param
.
downpour_table_param
.
accessor
.
fea_dim
-
2
sparse_table
.
fea_dim
=
downpour
.
emb_dim
+
2
sparse_table
.
label_var_name
=
"click"
# dense table should specify:
def
gen_trainer_desc
(
self
,
dataset
=
None
,
fleet_desc
=
None
,
dense_table
=
downpour
.
dense_table
.
add
()
worker
=
"Downpour"
):
dense_table
.
table_id
=
\
super
(
DistMultiTrainer
,
self
).
gen_trainer_desc
(
fleet_desc
,
worker
)
fleet_desc
.
trainer_param
.
dense_table
.
table_id
self
.
proto_desc
.
class_name
=
"DistMultiTrainer"
# dense_value_name
self
.
proto_desc
.
data_desc
.
CopyFrom
(
dataset
.
proto_desc
)
dense_table
.
dense_value_name
.
CopyFrom
(
fleet_desc
.
trainer_param
(
worker_builder
=
DeviceWorkerFactory
()
).
dense_table
().
dense_variable_name
)
device_worker
=
worker_builder
.
create_device_worker
(
"Downpour"
)
# dense_grad_name
device_worker
.
gen_worker_desc
(
self
.
proto_desc
,
fleet_desc
)
dense_table
.
dense_grad_name
.
CopyFrom
(
fleet_desc
.
trainer_param
(
).
dense_table
().
dense_gradient_name
)
downpour
.
skipped_ops
.
extend
(
fleet_desc
.
trainer_param
.
skip_op
)
print
(
str
(
self
.
proto_desc
))
else
:
raise
ValueError
(
'ValueError: DeviceWorker %s '
'is not supported in DistMultiTrainer'
%
worker
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录