Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
39014b9f
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
39014b9f
编写于
2月 02, 2019
作者:
D
dongdaxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix class register problem
上级
f0dd1201
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
50 addition
and
62 deletion
+50
-62
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+2
-20
paddle/fluid/framework/device_worker_factory.cc
paddle/fluid/framework/device_worker_factory.cc
+21
-0
paddle/fluid/framework/device_worker_factory.h
paddle/fluid/framework/device_worker_factory.h
+0
-19
paddle/fluid/framework/dist_multi_trainer.cc
paddle/fluid/framework/dist_multi_trainer.cc
+2
-2
paddle/fluid/framework/downpour_worker.cc
paddle/fluid/framework/downpour_worker.cc
+5
-1
paddle/fluid/framework/hogwild_worker.cc
paddle/fluid/framework/hogwild_worker.cc
+0
-1
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+0
-2
paddle/fluid/framework/trainer_factory.cc
paddle/fluid/framework/trainer_factory.cc
+20
-0
paddle/fluid/framework/trainer_factory.h
paddle/fluid/framework/trainer_factory.h
+0
-17
未找到文件。
paddle/fluid/framework/async_executor.cc
浏览文件 @
39014b9f
...
...
@@ -64,7 +64,6 @@ void AsyncExecutor::InitModel() {}
void
AsyncExecutor
::
SaveModel
(
const
std
::
string
&
path
)
{}
void
AsyncExecutor
::
RunFromFile
(
const
ProgramDesc
&
main_program
,
<<<<<<<
HEAD
const
std
::
string
&
data_feed_desc_str
,
const
std
::
vector
<
std
::
string
>&
filelist
,
const
int
thread_num
,
...
...
@@ -153,25 +152,8 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
_pull_dense_thread
->
stop
();
}
#endif
=======
const
std
::
string
&
trainer_desc_str
,
const
bool
debug
)
{
TrainerDesc
trainer_desc
;
google
::
protobuf
::
TextFormat
::
ParseFromString
(
trainer_desc_str
,
&
trainer_desc
);
std
::
shared_ptr
<
TrainerBase
>
trainer
;
trainer
=
TrainerFactory
::
CreateTrainer
(
trainer_desc
.
class_name
());
// initialize trainer
trainer
->
Initialize
(
trainer_desc
);
trainer
->
SetScope
(
root_scope_
);
trainer
->
SetDebug
(
debug
);
// prepare training environment and helper environment
trainer
->
InitTrainerEnv
(
main_program
,
place_
);
trainer
->
InitOtherEnv
(
main_program
);
// training and finalize training
trainer
->
Run
();
trainer
->
Finalize
();
>>>>>>>
add
dist_multi_trainer
for
distributed
training
,
add
trainer_factory
and
device_worker_factory
so
that
we
can
easily
extend
new
training
mode
,
add
pull
dense
worker
which
is
a
singleton
for
parameter
fetching
VLOG
(
3
)
<<
"start to run from files in async_executor"
;
VLOG
(
3
)
<<
"Drop current scope kids"
;
root_scope_
->
DropKids
();
return
;
...
...
paddle/fluid/framework/device_worker_factory.cc
浏览文件 @
39014b9f
...
...
@@ -20,6 +20,25 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
typedef
std
::
shared_ptr
<
DeviceWorker
>
(
*
Createdevice_workerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
Createdevice_workerFunction
>
device_workerMap
;
device_workerMap
g_device_worker_map
;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
std
::
string
DeviceWorkerFactory
::
DeviceWorkerTypeList
()
{
std
::
string
device_worker_types
;
for
(
auto
iter
=
g_device_worker_map
.
begin
();
...
...
@@ -40,5 +59,7 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
return
g_device_worker_map
[
device_worker_class
]();
}
REGISTER_DEVICE_WORKER_CLASS
(
HogwildWorker
);
REGISTER_DEVICE_WORKER_CLASS
(
DownpourWorker
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/device_worker_factory.h
浏览文件 @
39014b9f
...
...
@@ -21,25 +21,6 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
typedef
std
::
shared_ptr
<
DeviceWorker
>
(
*
Createdevice_workerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
Createdevice_workerFunction
>
device_workerMap
;
device_workerMap
g_device_worker_map
;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
class
DeviceWorkerFactory
{
public:
static
std
::
string
DeviceWorkerTypeList
();
...
...
paddle/fluid/framework/dist_multi_trainer.cc
浏览文件 @
39014b9f
...
...
@@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -48,11 +47,13 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
fleet_ptr_
=
FleetWrapper
::
GetInstance
();
pull_dense_worker_
=
PullDenseWorker
::
GetInstance
();
pull_dense_worker_
->
Initialize
(
trainer_desc
);
VLOG
(
3
)
<<
"initialize pull dense worker"
;
}
void
DistMultiTrainer
::
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
{
pull_dense_worker_
->
SetRootScope
(
root_scope_
);
pull_dense_worker_
->
Start
();
VLOG
(
3
)
<<
"init other env done."
;
}
void
DistMultiTrainer
::
Finalize
()
{
...
...
@@ -62,6 +63,5 @@ void DistMultiTrainer::Finalize() {
pull_dense_worker_
->
Stop
();
}
REGISTER_TRAINER_CLASS
(
DistMultiTrainer
);
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/downpour_worker.cc
浏览文件 @
39014b9f
...
...
@@ -134,6 +134,7 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
}
void
DownpourWorker
::
TrainFiles
()
{
VLOG
(
3
)
<<
"Begin to train files"
;
platform
::
SetNumThreads
(
1
);
device_reader_
->
Start
();
int
batch_cnt
=
0
;
...
...
@@ -148,6 +149,7 @@ void DownpourWorker::TrainFiles() {
CollectLabelInfo
(
i
);
FillSparseValue
(
i
);
}
VLOG
(
3
)
<<
"fill sparse value for all sparse table done."
;
// do computation here
for
(
auto
&
op
:
ops_
)
{
...
...
@@ -179,6 +181,7 @@ void DownpourWorker::TrainFiles() {
*
thread_scope_
,
tid
,
dense_grad_names_
[
tid
],
&
push_sparse_status_
);
}
VLOG
(
3
)
<<
"push sparse and dense gradient done."
;
// the following code should be more precise and clean
// TODO(guru4elephant)
int32_t
tmp_push_dense_wait_times
=
-
1
;
...
...
@@ -210,16 +213,17 @@ void DownpourWorker::TrainFiles() {
push_sparse_status_
.
resize
(
0
);
}
/*
for (size_t i = 0; i < param_.dense_table_size(); ++i) {
uint64_t tid = static_cast<uint64_t>(param_.dense_table(i).table_id());
pull_dense_worker_->IncreaseThreadVersion(thread_id_, tid);
}
*/
thread_scope_
->
DropKids
();
++
batch_cnt
;
}
}
REGISTER_DEVICE_WORKER_CLASS
(
DownpourWorker
);
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/hogwild_worker.cc
浏览文件 @
39014b9f
...
...
@@ -129,6 +129,5 @@ void HogwildWorker::TrainFiles() {
}
}
REGISTER_DEVICE_WORKER_CLASS
(
HogwildWorker
);
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
39014b9f
...
...
@@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -66,6 +65,5 @@ void MultiTrainer::Finalize() {
}
}
REGISTER_TRAINER_CLASS
(
MultiTrainer
);
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/trainer_factory.cc
浏览文件 @
39014b9f
...
...
@@ -22,8 +22,24 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
typedef
std
::
shared_ptr
<
TrainerBase
>
(
*
CreatetrainerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
CreatetrainerFunction
>
trainerMap
;
trainerMap
g_trainer_map
;
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
std
::
string
TrainerFactory
::
TrainerTypeList
()
{
std
::
string
trainer_types
;
for
(
auto
iter
=
g_trainer_map
.
begin
();
iter
!=
g_trainer_map
.
end
();
++
iter
)
{
...
...
@@ -38,10 +54,14 @@ std::string TrainerFactory::TrainerTypeList() {
std
::
shared_ptr
<
TrainerBase
>
TrainerFactory
::
CreateTrainer
(
std
::
string
trainer_class
)
{
if
(
g_trainer_map
.
count
(
trainer_class
)
<
1
)
{
LOG
(
WARNING
)
<<
"Trainer class: "
<<
trainer_class
<<
" not defined"
;
LOG
(
WARNING
)
<<
TrainerTypeList
();
exit
(
-
1
);
}
return
g_trainer_map
[
trainer_class
]();
}
REGISTER_TRAINER_CLASS
(
MultiTrainer
);
REGISTER_TRAINER_CLASS
(
DistMultiTrainer
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/trainer_factory.h
浏览文件 @
39014b9f
...
...
@@ -20,23 +20,6 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
typedef
std
::
shared_ptr
<
TrainerBase
>
(
*
CreatetrainerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
CreatetrainerFunction
>
trainerMap
;
extern
trainerMap
g_trainer_map
;
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
class
TrainerFactory
{
public:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录