Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
d87ba58c
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d87ba58c
编写于
3月 26, 2019
作者:
D
dongdaxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine document of python API, make device_worker and trainer's API private
test=develop
上级
5687f234
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
112 addition
and
90 deletion
+112
-90
paddle/fluid/framework/downpour_worker.cc
paddle/fluid/framework/downpour_worker.cc
+65
-53
python/paddle/fluid/async_executor.py
python/paddle/fluid/async_executor.py
+3
-5
python/paddle/fluid/device_worker.py
python/paddle/fluid/device_worker.py
+19
-7
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+12
-12
python/paddle/fluid/trainer_desc.py
python/paddle/fluid/trainer_desc.py
+12
-12
python/paddle/fluid/trainer_factory.py
python/paddle/fluid/trainer_factory.py
+1
-1
未找到文件。
paddle/fluid/framework/downpour_worker.cc
浏览文件 @
d87ba58c
...
@@ -279,11 +279,8 @@ void DownpourWorker::TrainFilesWithProfiler() {
...
@@ -279,11 +279,8 @@ void DownpourWorker::TrainFilesWithProfiler() {
total_time
+=
timeline
.
ElapsedSec
();
total_time
+=
timeline
.
ElapsedSec
();
VLOG
(
3
)
<<
"push sparse and dense gradient done."
;
VLOG
(
3
)
<<
"push sparse and dense gradient done."
;
int32_t
tmp_push_dense_wait_times
=
-
1
;
int32_t
tmp_push_dense_wait_times
=
-
1
;
int32_t
tmp_push_sparse_wait_times
=
-
1
;
static
uint32_t
push_dense_wait_times
=
static
uint32_t
push_dense_wait_times
=
static_cast
<
uint32_t
>
(
tmp_push_dense_wait_times
);
static_cast
<
uint32_t
>
(
tmp_push_dense_wait_times
);
static
uint32_t
push_sparse_wait_times
=
static_cast
<
uint32_t
>
(
tmp_push_sparse_wait_times
);
if
(
push_dense_status_
.
size
()
>=
push_dense_wait_times
)
{
if
(
push_dense_status_
.
size
()
>=
push_dense_wait_times
)
{
for
(
auto
&
t
:
push_dense_status_
)
{
for
(
auto
&
t
:
push_dense_status_
)
{
t
.
wait
();
t
.
wait
();
...
@@ -297,6 +294,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
...
@@ -297,6 +294,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
}
if
(
need_to_push_sparse_
)
{
if
(
need_to_push_sparse_
)
{
int32_t
tmp_push_sparse_wait_times
=
-
1
;
static
uint32_t
push_sparse_wait_times
=
static_cast
<
uint32_t
>
(
tmp_push_sparse_wait_times
);
if
(
push_sparse_status_
.
size
()
>=
push_sparse_wait_times
)
{
if
(
push_sparse_status_
.
size
()
>=
push_sparse_wait_times
)
{
for
(
auto
&
t
:
push_sparse_status_
)
{
for
(
auto
&
t
:
push_sparse_status_
)
{
t
.
wait
();
t
.
wait
();
...
@@ -311,6 +311,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
...
@@ -311,6 +311,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
VLOG
(
3
)
<<
"going to increase thread version"
;
VLOG
(
3
)
<<
"going to increase thread version"
;
VLOG
(
3
)
<<
"push dense table id size: "
VLOG
(
3
)
<<
"push dense table id size: "
<<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
<<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
}
if
(
need_to_push_dense_
)
{
for
(
size_t
i
=
0
;
for
(
size_t
i
=
0
;
i
<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
++
i
)
{
i
<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
uint64_t
tid
=
static_cast
<
uint64_t
>
(
...
@@ -381,69 +384,78 @@ void DownpourWorker::TrainFiles() {
...
@@ -381,69 +384,78 @@ void DownpourWorker::TrainFiles() {
}
}
}
}
// push gradients here
if
(
need_to_push_sparse_
)
{
for
(
size_t
i
=
0
;
i
<
param_
.
program_config
(
0
).
push_sparse_table_id_size
();
// push gradients here
++
i
)
{
for
(
size_t
i
=
0
;
uint64_t
tid
=
static_cast
<
uint64_t
>
(
i
<
param_
.
program_config
(
0
).
push_sparse_table_id_size
();
++
i
)
{
param_
.
program_config
(
0
).
push_sparse_table_id
(
i
));
uint64_t
tid
=
static_cast
<
uint64_t
>
(
TableParameter
table
;
param_
.
program_config
(
0
).
push_sparse_table_id
(
i
));
for
(
auto
i
:
param_
.
sparse_table
())
{
TableParameter
table
;
if
(
i
.
table_id
()
==
tid
)
{
for
(
auto
i
:
param_
.
sparse_table
())
{
table
=
i
;
if
(
i
.
table_id
()
==
tid
)
{
break
;
table
=
i
;
break
;
}
}
}
fleet_ptr_
->
PushSparseVarsWithLabelAsync
(
*
thread_scope_
,
tid
,
features_
[
tid
],
feature_labels_
[
tid
],
sparse_key_names_
[
tid
],
sparse_grad_names_
[
tid
],
table
.
emb_dim
(),
&
feature_grads_
[
tid
],
&
push_sparse_status_
);
}
}
fleet_ptr_
->
PushSparseVarsWithLabelAsync
(
*
thread_scope_
,
tid
,
features_
[
tid
],
feature_labels_
[
tid
],
sparse_key_names_
[
tid
],
sparse_grad_names_
[
tid
],
table
.
emb_dim
(),
&
feature_grads_
[
tid
],
&
push_sparse_status_
);
}
}
for
(
size_t
i
=
0
;
i
<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
if
(
need_to_push_dense_
)
{
++
i
)
{
for
(
size_t
i
=
0
;
uint64_t
tid
=
static_cast
<
uint64_t
>
(
i
<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
++
i
)
{
param_
.
program_config
(
0
).
push_dense_table_id
(
i
));
uint64_t
tid
=
static_cast
<
uint64_t
>
(
fleet_ptr_
->
PushDenseVarsAsync
(
param_
.
program_config
(
0
).
push_dense_table_id
(
i
));
*
thread_scope_
,
tid
,
dense_grad_names_
[
tid
],
&
push_sparse_status_
);
fleet_ptr_
->
PushDenseVarsAsync
(
}
*
thread_scope_
,
tid
,
dense_grad_names_
[
tid
],
&
push_sparse_status_
);
}
VLOG
(
3
)
<<
"push dense gradient done."
;
// the following code should be more precise and clean
// TODO(guru4elephant)
int32_t
tmp_push_dense_wait_times
=
-
1
;
static
uint32_t
push_dense_wait_times
=
static_cast
<
uint32_t
>
(
tmp_push_dense_wait_times
);
VLOG
(
3
)
<<
"push sparse and dense gradient done."
;
if
(
push_dense_status_
.
size
()
>=
push_dense_wait_times
)
{
// the following code should be more precise and clean
for
(
auto
&
t
:
push_dense_status_
)
{
// TODO(guru4elephant)
t
.
wait
();
int32_t
tmp_push_dense_wait_times
=
-
1
;
}
int32_t
tmp_push_sparse_wait_times
=
-
1
;
push_dense_status_
.
resize
(
0
);
static
uint32_t
push_dense_wait_times
=
static_cast
<
uint32_t
>
(
tmp_push_dense_wait_times
);
static
uint32_t
push_sparse_wait_times
=
static_cast
<
uint32_t
>
(
tmp_push_sparse_wait_times
);
if
(
push_dense_status_
.
size
()
>=
push_dense_wait_times
)
{
for
(
auto
&
t
:
push_dense_status_
)
{
t
.
wait
();
}
}
push_dense_status_
.
resize
(
0
);
}
if
(
tmp_push_dense_wait_times
==
-
1
)
{
if
(
tmp_push_dense_wait_times
==
-
1
)
{
push_dense_status_
.
resize
(
0
);
push_dense_status_
.
resize
(
0
);
}
}
}
if
(
push_sparse_status_
.
size
()
>=
push_sparse_wait_times
)
{
if
(
need_to_push_sparse_
)
{
for
(
auto
&
t
:
push_sparse_status_
)
{
VLOG
(
3
)
<<
"push sparse gradient done."
;
t
.
wait
();
int32_t
tmp_push_sparse_wait_times
=
-
1
;
static
uint32_t
push_sparse_wait_times
=
static_cast
<
uint32_t
>
(
tmp_push_sparse_wait_times
);
if
(
push_sparse_status_
.
size
()
>=
push_sparse_wait_times
)
{
for
(
auto
&
t
:
push_sparse_status_
)
{
t
.
wait
();
}
push_sparse_status_
.
resize
(
0
);
}
}
push_sparse_status_
.
resize
(
0
);
}
if
(
tmp_push_sparse_wait_times
==
-
1
)
{
if
(
tmp_push_sparse_wait_times
==
-
1
)
{
push_sparse_status_
.
resize
(
0
);
push_sparse_status_
.
resize
(
0
);
}
}
}
for
(
size_t
i
=
0
;
i
<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
if
(
need_to_push_dense_
)
{
++
i
)
{
for
(
size_t
i
=
0
;
uint64_t
tid
=
static_cast
<
uint64_t
>
(
i
<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
++
i
)
{
param_
.
program_config
(
0
).
push_dense_table_id
(
i
));
uint64_t
tid
=
static_cast
<
uint64_t
>
(
pull_dense_worker_
->
IncreaseThreadVersion
(
thread_id_
,
tid
);
param_
.
program_config
(
0
).
push_dense_table_id
(
i
));
pull_dense_worker_
->
IncreaseThreadVersion
(
thread_id_
,
tid
);
}
}
}
PrintFetchVars
();
PrintFetchVars
();
...
...
python/paddle/fluid/async_executor.py
浏览文件 @
d87ba58c
...
@@ -154,10 +154,8 @@ class AsyncExecutor(object):
...
@@ -154,10 +154,8 @@ class AsyncExecutor(object):
with
open
(
"trainer_desc.proto"
,
"w"
)
as
fout
:
with
open
(
"trainer_desc.proto"
,
"w"
)
as
fout
:
fout
.
write
(
trainer
.
_desc
())
fout
.
write
(
trainer
.
_desc
())
# define a trainer and a device_worker here
# define a trainer and a device_worker here
self
.
executor
.
run_from_files
(
program_desc
,
self
.
executor
.
run_from_files
(
program_desc
,
trainer
.
_desc
(),
debug
)
trainer
.
_desc
(),
debug
)
'''
def
run
(
self
,
def
run
(
self
,
program
,
program
,
data_feed
,
data_feed
,
...
@@ -228,8 +226,8 @@ class AsyncExecutor(object):
...
@@ -228,8 +226,8 @@ class AsyncExecutor(object):
self
.
executor
.
run_from_files
(
program_desc
,
self
.
executor
.
run_from_files
(
program_desc
,
data_feed
.
desc
(),
filelist
,
thread_num
,
data_feed
.
desc
(),
filelist
,
thread_num
,
fetch_var_names, mode, debug,
str(id(program_desc)))
fetch_var_names
,
mode
,
debug
,
'''
str
(
id
(
program_desc
)))
def
download_data
(
self
,
def
download_data
(
self
,
afs_path
,
afs_path
,
...
...
python/paddle/fluid/device_worker.py
浏览文件 @
d87ba58c
...
@@ -19,7 +19,10 @@ __all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']
...
@@ -19,7 +19,10 @@ __all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']
class
DeviceWorker
(
object
):
class
DeviceWorker
(
object
):
"""
"""
DeviceWorker is a abstract class, which generates worker desc.
DeviceWorker is a abstract class, which generates worker desc.
This class is an inner class that we do computation logics within
the implementation. For example, execution of a program or a graph.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
"""
"""
Init.
Init.
...
@@ -27,10 +30,16 @@ class DeviceWorker(object):
...
@@ -27,10 +30,16 @@ class DeviceWorker(object):
self
.
program_
=
None
self
.
program_
=
None
self
.
infer_
=
None
self
.
infer_
=
None
def
set_infer
(
self
,
infer
=
False
):
def
_set_infer
(
self
,
infer
=
False
):
"""
set inference flag for current device worker
Args:
infer(bool): whether to do inference
"""
self
.
infer_
=
infer
self
.
infer_
=
infer
def
set_fleet_desc
(
self
,
fleet_desc
):
def
_
set_fleet_desc
(
self
,
fleet_desc
):
"""
"""
Set fleet desc.
Set fleet desc.
...
@@ -39,7 +48,7 @@ class DeviceWorker(object):
...
@@ -39,7 +48,7 @@ class DeviceWorker(object):
"""
"""
self
.
fleet_desc_
=
fleet_desc
self
.
fleet_desc_
=
fleet_desc
def
set_program
(
self
,
program
):
def
_
set_program
(
self
,
program
):
"""
"""
Set program.
Set program.
...
@@ -48,7 +57,7 @@ class DeviceWorker(object):
...
@@ -48,7 +57,7 @@ class DeviceWorker(object):
"""
"""
self
.
program_
=
program
self
.
program_
=
program
def
gen_worker_desc
(
self
,
trainer_desc
):
def
_
gen_worker_desc
(
self
,
trainer_desc
):
"""
"""
Generator worker desc.
Generator worker desc.
...
@@ -65,13 +74,14 @@ class Hogwild(DeviceWorker):
...
@@ -65,13 +74,14 @@ class Hogwild(DeviceWorker):
Hogwild is a kind of SGD algorithm.
Hogwild is a kind of SGD algorithm.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
"""
"""
Init.
Init.
"""
"""
super
(
Hogwild
,
self
).
__init__
()
super
(
Hogwild
,
self
).
__init__
()
def
gen_worker_desc
(
self
,
trainer_desc
):
def
_
gen_worker_desc
(
self
,
trainer_desc
):
"""
"""
Generator worker desc, which device worker is HogwildWorker.
Generator worker desc, which device worker is HogwildWorker.
...
@@ -85,13 +95,15 @@ class DownpourSGD(DeviceWorker):
...
@@ -85,13 +95,15 @@ class DownpourSGD(DeviceWorker):
"""
"""
DownpourSGD is a kind of distributed SGD algorithm.
DownpourSGD is a kind of distributed SGD algorithm.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
"""
"""
Init.
Init.
initialize downpourSGD device worker
"""
"""
super
(
DownpourSGD
,
self
).
__init__
()
super
(
DownpourSGD
,
self
).
__init__
()
def
gen_worker_desc
(
self
,
trainer_desc
):
def
_
gen_worker_desc
(
self
,
trainer_desc
):
"""
"""
Generator worker desc, which device worker is DownpourWorker.
Generator worker desc, which device worker is DownpourWorker.
...
@@ -162,6 +174,6 @@ class DownpourSGD(DeviceWorker):
...
@@ -162,6 +174,6 @@ class DownpourSGD(DeviceWorker):
class
DeviceWorkerFactory
(
object
):
class
DeviceWorkerFactory
(
object
):
def
create_device_worker
(
self
,
worker_type
):
def
_
create_device_worker
(
self
,
worker_type
):
classname
=
worker_type
.
capitalize
()
classname
=
worker_type
.
capitalize
()
return
globals
()[
classname
]()
return
globals
()[
classname
]()
python/paddle/fluid/executor.py
浏览文件 @
d87ba58c
...
@@ -637,23 +637,23 @@ class Executor(object):
...
@@ -637,23 +637,23 @@ class Executor(object):
assert
len
(
fetch_list
)
==
len
(
fetch_info
)
assert
len
(
fetch_list
)
==
len
(
fetch_info
)
compiled
=
isinstance
(
program
,
compiler
.
CompiledProgram
)
compiled
=
isinstance
(
program
,
compiler
.
CompiledProgram
)
if
not
compiled
:
if
not
compiled
:
trainer
=
TrainerFactory
().
create_trainer
(
program
.
_fleet_opt
)
trainer
=
TrainerFactory
().
_
create_trainer
(
program
.
_fleet_opt
)
trainer
.
set_program
(
program
)
trainer
.
_
set_program
(
program
)
else
:
else
:
trainer
=
TrainerFactory
().
create_trainer
(
trainer
=
TrainerFactory
().
_
create_trainer
(
program
.
program
.
_fleet_opt
)
program
.
program
.
_fleet_opt
)
trainer
.
set_program
(
program
.
program
)
trainer
.
_
set_program
(
program
.
program
)
if
thread
<=
0
:
if
thread
<=
0
:
if
dataset
.
thread_num
<=
0
:
if
dataset
.
thread_num
<=
0
:
raise
RuntimeError
(
raise
RuntimeError
(
"You should set thread num first, either in Dataset"
"You should set thread num first, either in Dataset"
"or in Executor.train_from_dataset"
)
"or in Executor.train_from_dataset"
)
else
:
else
:
trainer
.
set_thread
(
dataset
.
thread_num
)
trainer
.
_
set_thread
(
dataset
.
thread_num
)
else
:
else
:
trainer
.
set_thread
(
thread
)
trainer
.
_
set_thread
(
thread
)
trainer
.
set_debug
(
debug
)
trainer
.
_
set_debug
(
debug
)
trainer
.
set_fetch_var_and_info
(
fetch_list
,
fetch_info
,
print_period
)
trainer
.
_
set_fetch_var_and_info
(
fetch_list
,
fetch_info
,
print_period
)
return
trainer
return
trainer
def
infer_from_dataset
(
self
,
def
infer_from_dataset
(
self
,
...
@@ -679,7 +679,7 @@ class Executor(object):
...
@@ -679,7 +679,7 @@ class Executor(object):
for each run. default is global_scope
for each run. default is global_scope
thread(int): number of thread a user wants to run in this function. The actual number
thread(int): number of thread a user wants to run in this function. The actual number
of thread will be min(Dataset.thread_num, thread)
of thread will be min(Dataset.thread_num, thread)
debug(bool): whether a user wants to run
train
_from_dataset
debug(bool): whether a user wants to run
infer
_from_dataset
fetch_list(Variable List): fetch variable list, each variable
fetch_list(Variable List): fetch variable list, each variable
will be printed during training
will be printed during training
fetch_info(String List): print information for each variable
fetch_info(String List): print information for each variable
...
@@ -711,8 +711,8 @@ class Executor(object):
...
@@ -711,8 +711,8 @@ class Executor(object):
fetch_list
=
fetch_list
,
fetch_list
=
fetch_list
,
fetch_info
=
fetch_info
,
fetch_info
=
fetch_info
,
print_period
=
print_period
)
print_period
=
print_period
)
trainer
.
gen_trainer_desc
()
trainer
.
_
gen_trainer_desc
()
trainer
.
set_infer
(
True
)
trainer
.
_
set_infer
(
True
)
dataset
.
_prepare_to_run
()
dataset
.
_prepare_to_run
()
if
debug
:
if
debug
:
self
.
_dump_debug_info
(
program
=
program
,
trainer
=
trainer
)
self
.
_dump_debug_info
(
program
=
program
,
trainer
=
trainer
)
...
@@ -784,7 +784,7 @@ class Executor(object):
...
@@ -784,7 +784,7 @@ class Executor(object):
fetch_list
=
fetch_list
,
fetch_list
=
fetch_list
,
fetch_info
=
fetch_info
,
fetch_info
=
fetch_info
,
print_period
=
print_period
)
print_period
=
print_period
)
trainer
.
gen_trainer_desc
()
trainer
.
_
gen_trainer_desc
()
dataset
.
_prepare_to_run
()
dataset
.
_prepare_to_run
()
if
debug
:
if
debug
:
self
.
_dump_debug_info
(
program
=
program
,
trainer
=
trainer
)
self
.
_dump_debug_info
(
program
=
program
,
trainer
=
trainer
)
...
...
python/paddle/fluid/trainer_desc.py
浏览文件 @
d87ba58c
...
@@ -37,32 +37,32 @@ class TrainerDesc(object):
...
@@ -37,32 +37,32 @@ class TrainerDesc(object):
self
.
program_
=
None
self
.
program_
=
None
self
.
infer_
=
False
self
.
infer_
=
False
def
set_fetch_var_and_info
(
self
,
fetch_vars
,
fetch_info
,
print_period
):
def
_
set_fetch_var_and_info
(
self
,
fetch_vars
,
fetch_info
,
print_period
):
for
i
,
v
in
enumerate
(
fetch_vars
):
for
i
,
v
in
enumerate
(
fetch_vars
):
self
.
proto_desc
.
fetch_config
.
fetch_var_names
.
extend
([
v
.
name
])
self
.
proto_desc
.
fetch_config
.
fetch_var_names
.
extend
([
v
.
name
])
self
.
proto_desc
.
fetch_config
.
fetch_var_str_format
.
extend
(
self
.
proto_desc
.
fetch_config
.
fetch_var_str_format
.
extend
(
[
fetch_info
[
i
]])
[
fetch_info
[
i
]])
self
.
proto_desc
.
fetch_config
.
print_period
=
print_period
self
.
proto_desc
.
fetch_config
.
print_period
=
print_period
def
set_debug
(
self
,
debug
):
def
_
set_debug
(
self
,
debug
):
self
.
proto_desc
.
debug
=
debug
self
.
proto_desc
.
debug
=
debug
def
set_thread
(
self
,
thread_num
):
def
_
set_thread
(
self
,
thread_num
):
self
.
proto_desc
.
thread_num
=
thread_num
self
.
proto_desc
.
thread_num
=
thread_num
def
set_device_worker
(
self
,
device_worker
):
def
_
set_device_worker
(
self
,
device_worker
):
self
.
device_worker_
=
device_worker
self
.
device_worker_
=
device_worker
def
set_infer
(
self
,
infer
):
def
_
set_infer
(
self
,
infer
):
self
.
infer_
=
infer
self
.
infer_
=
infer
def
set_fleet_desc
(
self
,
fleet_desc
):
def
_
set_fleet_desc
(
self
,
fleet_desc
):
self
.
fleet_desc_
=
fleet_desc
self
.
fleet_desc_
=
fleet_desc
def
gen_trainer_desc
(
self
):
def
_
gen_trainer_desc
(
self
):
pass
pass
def
set_program
(
self
,
program
):
def
_
set_program
(
self
,
program
):
self
.
program_
=
program
self
.
program_
=
program
def
_desc
(
self
):
def
_desc
(
self
):
...
@@ -74,11 +74,11 @@ class MultiTrainer(TrainerDesc):
...
@@ -74,11 +74,11 @@ class MultiTrainer(TrainerDesc):
super
(
MultiTrainer
,
self
).
__init__
()
super
(
MultiTrainer
,
self
).
__init__
()
pass
pass
def
set_program
(
self
,
program
):
def
_
set_program
(
self
,
program
):
super
(
MultiTrainer
,
self
).
set_program
(
program
)
super
(
MultiTrainer
,
self
).
set_program
(
program
)
self
.
program_
=
program
self
.
program_
=
program
def
gen_trainer_desc
(
self
):
def
_
gen_trainer_desc
(
self
):
super
(
MultiTrainer
,
self
).
gen_trainer_desc
()
super
(
MultiTrainer
,
self
).
gen_trainer_desc
()
self
.
proto_desc
.
class_name
=
"MultiTrainer"
self
.
proto_desc
.
class_name
=
"MultiTrainer"
self
.
device_worker_
.
set_infer
(
self
.
infer_
)
self
.
device_worker_
.
set_infer
(
self
.
infer_
)
...
@@ -90,11 +90,11 @@ class DistMultiTrainer(TrainerDesc):
...
@@ -90,11 +90,11 @@ class DistMultiTrainer(TrainerDesc):
super
(
DistMultiTrainer
,
self
).
__init__
()
super
(
DistMultiTrainer
,
self
).
__init__
()
pass
pass
def
set_program
(
self
,
program
):
def
_
set_program
(
self
,
program
):
super
(
DistMultiTrainer
,
self
).
set_program
(
program
)
super
(
DistMultiTrainer
,
self
).
set_program
(
program
)
self
.
program_
=
program
self
.
program_
=
program
def
gen_trainer_desc
(
self
):
def
_
gen_trainer_desc
(
self
):
super
(
DistMultiTrainer
,
self
).
gen_trainer_desc
()
super
(
DistMultiTrainer
,
self
).
gen_trainer_desc
()
self
.
proto_desc
.
class_name
=
"DistMultiTrainer"
self
.
proto_desc
.
class_name
=
"DistMultiTrainer"
if
self
.
program_
==
None
:
if
self
.
program_
==
None
:
...
...
python/paddle/fluid/trainer_factory.py
浏览文件 @
d87ba58c
...
@@ -22,7 +22,7 @@ class TrainerFactory(object):
...
@@ -22,7 +22,7 @@ class TrainerFactory(object):
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
def
create_trainer
(
self
,
opt_info
=
None
):
def
_
create_trainer
(
self
,
opt_info
=
None
):
trainer
=
None
trainer
=
None
device_worker
=
None
device_worker
=
None
if
opt_info
==
None
:
if
opt_info
==
None
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录