Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
39014b9f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
39014b9f
编写于
2月 02, 2019
作者:
D
dongdaxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix class register problem
上级
f0dd1201
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
50 addition
and
62 deletion
+50
-62
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+2
-20
paddle/fluid/framework/device_worker_factory.cc
paddle/fluid/framework/device_worker_factory.cc
+21
-0
paddle/fluid/framework/device_worker_factory.h
paddle/fluid/framework/device_worker_factory.h
+0
-19
paddle/fluid/framework/dist_multi_trainer.cc
paddle/fluid/framework/dist_multi_trainer.cc
+2
-2
paddle/fluid/framework/downpour_worker.cc
paddle/fluid/framework/downpour_worker.cc
+5
-1
paddle/fluid/framework/hogwild_worker.cc
paddle/fluid/framework/hogwild_worker.cc
+0
-1
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+0
-2
paddle/fluid/framework/trainer_factory.cc
paddle/fluid/framework/trainer_factory.cc
+20
-0
paddle/fluid/framework/trainer_factory.h
paddle/fluid/framework/trainer_factory.h
+0
-17
未找到文件。
paddle/fluid/framework/async_executor.cc
浏览文件 @
39014b9f
...
...
@@ -64,7 +64,6 @@ void AsyncExecutor::InitModel() {}
void
AsyncExecutor
::
SaveModel
(
const
std
::
string
&
path
)
{}
void
AsyncExecutor
::
RunFromFile
(
const
ProgramDesc
&
main_program
,
<<<<<<<
HEAD
const
std
::
string
&
data_feed_desc_str
,
const
std
::
vector
<
std
::
string
>&
filelist
,
const
int
thread_num
,
...
...
@@ -153,25 +152,8 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
_pull_dense_thread
->
stop
();
}
#endif
=======
const
std
::
string
&
trainer_desc_str
,
const
bool
debug
)
{
TrainerDesc
trainer_desc
;
google
::
protobuf
::
TextFormat
::
ParseFromString
(
trainer_desc_str
,
&
trainer_desc
);
std
::
shared_ptr
<
TrainerBase
>
trainer
;
trainer
=
TrainerFactory
::
CreateTrainer
(
trainer_desc
.
class_name
());
// initialize trainer
trainer
->
Initialize
(
trainer_desc
);
trainer
->
SetScope
(
root_scope_
);
trainer
->
SetDebug
(
debug
);
// prepare training environment and helper environment
trainer
->
InitTrainerEnv
(
main_program
,
place_
);
trainer
->
InitOtherEnv
(
main_program
);
// training and finalize training
trainer
->
Run
();
trainer
->
Finalize
();
>>>>>>>
add
dist_multi_trainer
for
distributed
training
,
add
trainer_factory
and
device_worker_factory
so
that
we
can
easily
extend
new
training
mode
,
add
pull
dense
worker
which
is
a
singleton
for
parameter
fetching
VLOG
(
3
)
<<
"start to run from files in async_executor"
;
VLOG
(
3
)
<<
"Drop current scope kids"
;
root_scope_
->
DropKids
();
return
;
...
...
paddle/fluid/framework/device_worker_factory.cc
浏览文件 @
39014b9f
...
...
@@ -20,6 +20,25 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
typedef
std
::
shared_ptr
<
DeviceWorker
>
(
*
Createdevice_workerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
Createdevice_workerFunction
>
device_workerMap
;
device_workerMap
g_device_worker_map
;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
std
::
string
DeviceWorkerFactory
::
DeviceWorkerTypeList
()
{
std
::
string
device_worker_types
;
for
(
auto
iter
=
g_device_worker_map
.
begin
();
...
...
@@ -40,5 +59,7 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
return
g_device_worker_map
[
device_worker_class
]();
}
REGISTER_DEVICE_WORKER_CLASS
(
HogwildWorker
);
REGISTER_DEVICE_WORKER_CLASS
(
DownpourWorker
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/device_worker_factory.h
浏览文件 @
39014b9f
...
...
@@ -21,25 +21,6 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
typedef
std
::
shared_ptr
<
DeviceWorker
>
(
*
Createdevice_workerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
Createdevice_workerFunction
>
device_workerMap
;
device_workerMap
g_device_worker_map
;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
class
DeviceWorkerFactory
{
public:
static
std
::
string
DeviceWorkerTypeList
();
...
...
paddle/fluid/framework/dist_multi_trainer.cc
浏览文件 @
39014b9f
...
...
@@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -48,11 +47,13 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
fleet_ptr_
=
FleetWrapper
::
GetInstance
();
pull_dense_worker_
=
PullDenseWorker
::
GetInstance
();
pull_dense_worker_
->
Initialize
(
trainer_desc
);
VLOG
(
3
)
<<
"initialize pull dense worker"
;
}
void
DistMultiTrainer
::
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
{
pull_dense_worker_
->
SetRootScope
(
root_scope_
);
pull_dense_worker_
->
Start
();
VLOG
(
3
)
<<
"init other env done."
;
}
void
DistMultiTrainer
::
Finalize
()
{
...
...
@@ -62,6 +63,5 @@ void DistMultiTrainer::Finalize() {
pull_dense_worker_
->
Stop
();
}
REGISTER_TRAINER_CLASS
(
DistMultiTrainer
);
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/downpour_worker.cc
浏览文件 @
39014b9f
...
...
@@ -134,6 +134,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
}
void
DownpourWorker
::
TrainFiles
()
{
VLOG
(
3
)
<<
"Begin to train files"
;
platform
::
SetNumThreads
(
1
);
device_reader_
->
Start
();
int
batch_cnt
=
0
;
...
...
@@ -148,6 +149,7 @@ void DownpourWorker::TrainFiles() {
CollectLabelInfo
(
i
);
FillSparseValue
(
i
);
}
VLOG
(
3
)
<<
"fill sparse value for all sparse table done."
;
// do computation here
for
(
auto
&
op
:
ops_
)
{
...
...
@@ -179,6 +181,7 @@ void DownpourWorker::TrainFiles() {
*
thread_scope_
,
tid
,
dense_grad_names_
[
tid
],
&
push_sparse_status_
);
}
VLOG
(
3
)
<<
"push sparse and dense gradient done."
;
// the following code should be more precise and clean
// TODO(guru4elephant)
int32_t
tmp_push_dense_wait_times
=
-
1
;
...
...
@@ -210,16 +213,17 @@ void DownpourWorker::TrainFiles() {
push_sparse_status_
.
resize
(
0
);
}
/*
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
*/
thread_scope_
->
DropKids
();
++
batch_cnt
;
}
}
REGISTER_DEVICE_WORKER_CLASS
(
DownpourWorker
);
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/hogwild_worker.cc
浏览文件 @
39014b9f
...
...
@@ -129,6 +129,5 @@ void HogwildWorker::TrainFiles() {
}
}
REGISTER_DEVICE_WORKER_CLASS
(
HogwildWorker
);
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
39014b9f
...
...
@@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -66,6 +65,5 @@ void MultiTrainer::Finalize() {
}
}
REGISTER_TRAINER_CLASS
(
MultiTrainer
);
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/trainer_factory.cc
浏览文件 @
39014b9f
...
...
@@ -22,8 +22,24 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
typedef
std
::
shared_ptr
<
TrainerBase
>
(
*
CreatetrainerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
CreatetrainerFunction
>
trainerMap
;
trainerMap
g_trainer_map
;
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
std
::
string
TrainerFactory
::
TrainerTypeList
()
{
std
::
string
trainer_types
;
for
(
auto
iter
=
g_trainer_map
.
begin
();
iter
!=
g_trainer_map
.
end
();
++
iter
)
{
...
...
@@ -38,10 +54,14 @@ std::string TrainerFactory::TrainerTypeList() {
std
::
shared_ptr
<
TrainerBase
>
TrainerFactory
::
CreateTrainer
(
std
::
string
trainer_class
)
{
if
(
g_trainer_map
.
count
(
trainer_class
)
<
1
)
{
LOG
(
WARNING
)
<<
"Trainer class: "
<<
trainer_class
<<
" not defined"
;
LOG
(
WARNING
)
<<
TrainerTypeList
();
exit
(
-
1
);
}
return
g_trainer_map
[
trainer_class
]();
}
REGISTER_TRAINER_CLASS
(
MultiTrainer
);
REGISTER_TRAINER_CLASS
(
DistMultiTrainer
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/trainer_factory.h
浏览文件 @
39014b9f
...
...
@@ -20,23 +20,6 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
typedef
std
::
shared_ptr
<
TrainerBase
>
(
*
CreatetrainerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
CreatetrainerFunction
>
trainerMap
;
extern
trainerMap
g_trainer_map
;
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
class
TrainerFactory
{
public:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录