Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d87ba58c
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d87ba58c
编写于
3月 26, 2019
作者:
D
dongdaxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine document of python API, make device_worker and trainer's API private
test=develop
上级
5687f234
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
112 addition
and
90 deletion
+112
-90
paddle/fluid/framework/downpour_worker.cc
paddle/fluid/framework/downpour_worker.cc
+65
-53
python/paddle/fluid/async_executor.py
python/paddle/fluid/async_executor.py
+3
-5
python/paddle/fluid/device_worker.py
python/paddle/fluid/device_worker.py
+19
-7
python/paddle/fluid/executor.py
python/paddle/fluid/executor.py
+12
-12
python/paddle/fluid/trainer_desc.py
python/paddle/fluid/trainer_desc.py
+12
-12
python/paddle/fluid/trainer_factory.py
python/paddle/fluid/trainer_factory.py
+1
-1
未找到文件。
paddle/fluid/framework/downpour_worker.cc
浏览文件 @
d87ba58c
...
@@ -279,11 +279,8 @@ void DownpourWorker::TrainFilesWithProfiler() {
...
@@ -279,11 +279,8 @@ void DownpourWorker::TrainFilesWithProfiler() {
total_time
+=
timeline
.
ElapsedSec
();
total_time
+=
timeline
.
ElapsedSec
();
VLOG
(
3
)
<<
"push sparse and dense gradient done."
;
VLOG
(
3
)
<<
"push sparse and dense gradient done."
;
int32_t
tmp_push_dense_wait_times
=
-
1
;
int32_t
tmp_push_dense_wait_times
=
-
1
;
int32_t
tmp_push_sparse_wait_times
=
-
1
;
static
uint32_t
push_dense_wait_times
=
static
uint32_t
push_dense_wait_times
=
static_cast
<
uint32_t
>
(
tmp_push_dense_wait_times
);
static_cast
<
uint32_t
>
(
tmp_push_dense_wait_times
);
static
uint32_t
push_sparse_wait_times
=
static_cast
<
uint32_t
>
(
tmp_push_sparse_wait_times
);
if
(
push_dense_status_
.
size
()
>=
push_dense_wait_times
)
{
if
(
push_dense_status_
.
size
()
>=
push_dense_wait_times
)
{
for
(
auto
&
t
:
push_dense_status_
)
{
for
(
auto
&
t
:
push_dense_status_
)
{
t
.
wait
();
t
.
wait
();
...
@@ -297,6 +294,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
...
@@ -297,6 +294,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
}
}
if
(
need_to_push_sparse_
)
{
if
(
need_to_push_sparse_
)
{
int32_t
tmp_push_sparse_wait_times
=
-
1
;
static
uint32_t
push_sparse_wait_times
=
static_cast
<
uint32_t
>
(
tmp_push_sparse_wait_times
);
if
(
push_sparse_status_
.
size
()
>=
push_sparse_wait_times
)
{
if
(
push_sparse_status_
.
size
()
>=
push_sparse_wait_times
)
{
for
(
auto
&
t
:
push_sparse_status_
)
{
for
(
auto
&
t
:
push_sparse_status_
)
{
t
.
wait
();
t
.
wait
();
...
@@ -311,6 +311,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
...
@@ -311,6 +311,9 @@ void DownpourWorker::TrainFilesWithProfiler() {
VLOG
(
3
)
<<
"going to increase thread version"
;
VLOG
(
3
)
<<
"going to increase thread version"
;
VLOG
(
3
)
<<
"push dense table id size: "
VLOG
(
3
)
<<
"push dense table id size: "
<<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
<<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
}
if
(
need_to_push_dense_
)
{
for
(
size_t
i
=
0
;
for
(
size_t
i
=
0
;
i
<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
++
i
)
{
i
<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
uint64_t
tid
=
static_cast
<
uint64_t
>
(
...
@@ -381,9 +384,10 @@ void DownpourWorker::TrainFiles() {
...
@@ -381,9 +384,10 @@ void DownpourWorker::TrainFiles() {
}
}
}
}
if
(
need_to_push_sparse_
)
{
// push gradients here
// push gradients here
for
(
size_t
i
=
0
;
i
<
param_
.
program_config
(
0
).
push_sparse_table_id_size
()
;
for
(
size_t
i
=
0
;
++
i
)
{
i
<
param_
.
program_config
(
0
).
push_sparse_table_id_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
program_config
(
0
).
push_sparse_table_id
(
i
));
param_
.
program_config
(
0
).
push_sparse_table_id
(
i
));
TableParameter
table
;
TableParameter
table
;
...
@@ -398,24 +402,23 @@ void DownpourWorker::TrainFiles() {
...
@@ -398,24 +402,23 @@ void DownpourWorker::TrainFiles() {
sparse_key_names_
[
tid
],
sparse_grad_names_
[
tid
],
table
.
emb_dim
(),
sparse_key_names_
[
tid
],
sparse_grad_names_
[
tid
],
table
.
emb_dim
(),
&
feature_grads_
[
tid
],
&
push_sparse_status_
);
&
feature_grads_
[
tid
],
&
push_sparse_status_
);
}
}
}
for
(
size_t
i
=
0
;
i
<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
if
(
need_to_push_dense_
)
{
++
i
)
{
for
(
size_t
i
=
0
;
i
<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
program_config
(
0
).
push_dense_table_id
(
i
));
param_
.
program_config
(
0
).
push_dense_table_id
(
i
));
fleet_ptr_
->
PushDenseVarsAsync
(
fleet_ptr_
->
PushDenseVarsAsync
(
*
thread_scope_
,
tid
,
dense_grad_names_
[
tid
],
&
push_sparse_status_
);
*
thread_scope_
,
tid
,
dense_grad_names_
[
tid
],
&
push_sparse_status_
);
}
}
VLOG
(
3
)
<<
"push sparse and
dense gradient done."
;
VLOG
(
3
)
<<
"push
dense gradient done."
;
// the following code should be more precise and clean
// the following code should be more precise and clean
// TODO(guru4elephant)
// TODO(guru4elephant)
int32_t
tmp_push_dense_wait_times
=
-
1
;
int32_t
tmp_push_dense_wait_times
=
-
1
;
int32_t
tmp_push_sparse_wait_times
=
-
1
;
static
uint32_t
push_dense_wait_times
=
static
uint32_t
push_dense_wait_times
=
static_cast
<
uint32_t
>
(
tmp_push_dense_wait_times
);
static_cast
<
uint32_t
>
(
tmp_push_dense_wait_times
);
static
uint32_t
push_sparse_wait_times
=
static_cast
<
uint32_t
>
(
tmp_push_sparse_wait_times
);
if
(
push_dense_status_
.
size
()
>=
push_dense_wait_times
)
{
if
(
push_dense_status_
.
size
()
>=
push_dense_wait_times
)
{
for
(
auto
&
t
:
push_dense_status_
)
{
for
(
auto
&
t
:
push_dense_status_
)
{
...
@@ -427,7 +430,13 @@ void DownpourWorker::TrainFiles() {
...
@@ -427,7 +430,13 @@ void DownpourWorker::TrainFiles() {
if
(
tmp_push_dense_wait_times
==
-
1
)
{
if
(
tmp_push_dense_wait_times
==
-
1
)
{
push_dense_status_
.
resize
(
0
);
push_dense_status_
.
resize
(
0
);
}
}
}
if
(
need_to_push_sparse_
)
{
VLOG
(
3
)
<<
"push sparse gradient done."
;
int32_t
tmp_push_sparse_wait_times
=
-
1
;
static
uint32_t
push_sparse_wait_times
=
static_cast
<
uint32_t
>
(
tmp_push_sparse_wait_times
);
if
(
push_sparse_status_
.
size
()
>=
push_sparse_wait_times
)
{
if
(
push_sparse_status_
.
size
()
>=
push_sparse_wait_times
)
{
for
(
auto
&
t
:
push_sparse_status_
)
{
for
(
auto
&
t
:
push_sparse_status_
)
{
t
.
wait
();
t
.
wait
();
...
@@ -438,13 +447,16 @@ void DownpourWorker::TrainFiles() {
...
@@ -438,13 +447,16 @@ void DownpourWorker::TrainFiles() {
if
(
tmp_push_sparse_wait_times
==
-
1
)
{
if
(
tmp_push_sparse_wait_times
==
-
1
)
{
push_sparse_status_
.
resize
(
0
);
push_sparse_status_
.
resize
(
0
);
}
}
}
for
(
size_t
i
=
0
;
i
<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
if
(
need_to_push_dense_
)
{
++
i
)
{
for
(
size_t
i
=
0
;
i
<
param_
.
program_config
(
0
).
push_dense_table_id_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
program_config
(
0
).
push_dense_table_id
(
i
));
param_
.
program_config
(
0
).
push_dense_table_id
(
i
));
pull_dense_worker_
->
IncreaseThreadVersion
(
thread_id_
,
tid
);
pull_dense_worker_
->
IncreaseThreadVersion
(
thread_id_
,
tid
);
}
}
}
PrintFetchVars
();
PrintFetchVars
();
thread_scope_
->
DropKids
();
thread_scope_
->
DropKids
();
...
...
python/paddle/fluid/async_executor.py
浏览文件 @
d87ba58c
...
@@ -154,10 +154,8 @@ class AsyncExecutor(object):
...
@@ -154,10 +154,8 @@ class AsyncExecutor(object):
with
open
(
"trainer_desc.proto"
,
"w"
)
as
fout
:
with
open
(
"trainer_desc.proto"
,
"w"
)
as
fout
:
fout
.
write
(
trainer
.
_desc
())
fout
.
write
(
trainer
.
_desc
())
# define a trainer and a device_worker here
# define a trainer and a device_worker here
self
.
executor
.
run_from_files
(
program_desc
,
self
.
executor
.
run_from_files
(
program_desc
,
trainer
.
_desc
(),
debug
)
trainer
.
_desc
(),
debug
)
'''
def
run
(
self
,
def
run
(
self
,
program
,
program
,
data_feed
,
data_feed
,
...
@@ -228,8 +226,8 @@ class AsyncExecutor(object):
...
@@ -228,8 +226,8 @@ class AsyncExecutor(object):
self
.
executor
.
run_from_files
(
program_desc
,
self
.
executor
.
run_from_files
(
program_desc
,
data_feed
.
desc
(),
filelist
,
thread_num
,
data_feed
.
desc
(),
filelist
,
thread_num
,
fetch_var_names, mode, debug,
str(id(program_desc)))
fetch_var_names
,
mode
,
debug
,
'''
str
(
id
(
program_desc
)))
def
download_data
(
self
,
def
download_data
(
self
,
afs_path
,
afs_path
,
...
...
python/paddle/fluid/device_worker.py
浏览文件 @
d87ba58c
...
@@ -19,7 +19,10 @@ __all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']
...
@@ -19,7 +19,10 @@ __all__ = ['DeviceWorker', 'Hogwild', 'DownpourSGD']
class
DeviceWorker
(
object
):
class
DeviceWorker
(
object
):
"""
"""
DeviceWorker is a abstract class, which generates worker desc.
DeviceWorker is a abstract class, which generates worker desc.
This class is an inner class that we do computation logics within
the implementation. For example, execution of a program or a graph.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
"""
"""
Init.
Init.
...
@@ -27,10 +30,16 @@ class DeviceWorker(object):
...
@@ -27,10 +30,16 @@ class DeviceWorker(object):
self
.
program_
=
None
self
.
program_
=
None
self
.
infer_
=
None
self
.
infer_
=
None
def
set_infer
(
self
,
infer
=
False
):
def
_set_infer
(
self
,
infer
=
False
):
"""
set inference flag for current device worker
Args:
infer(bool): whether to do inference
"""
self
.
infer_
=
infer
self
.
infer_
=
infer
def
set_fleet_desc
(
self
,
fleet_desc
):
def
_
set_fleet_desc
(
self
,
fleet_desc
):
"""
"""
Set fleet desc.
Set fleet desc.
...
@@ -39,7 +48,7 @@ class DeviceWorker(object):
...
@@ -39,7 +48,7 @@ class DeviceWorker(object):
"""
"""
self
.
fleet_desc_
=
fleet_desc
self
.
fleet_desc_
=
fleet_desc
def
set_program
(
self
,
program
):
def
_
set_program
(
self
,
program
):
"""
"""
Set program.
Set program.
...
@@ -48,7 +57,7 @@ class DeviceWorker(object):
...
@@ -48,7 +57,7 @@ class DeviceWorker(object):
"""
"""
self
.
program_
=
program
self
.
program_
=
program
def
gen_worker_desc
(
self
,
trainer_desc
):
def
_
gen_worker_desc
(
self
,
trainer_desc
):
"""
"""
Generator worker desc.
Generator worker desc.
...
@@ -65,13 +74,14 @@ class Hogwild(DeviceWorker):
...
@@ -65,13 +74,14 @@ class Hogwild(DeviceWorker):
Hogwild is a kind of SGD algorithm.
Hogwild is a kind of SGD algorithm.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
"""
"""
Init.
Init.
"""
"""
super
(
Hogwild
,
self
).
__init__
()
super
(
Hogwild
,
self
).
__init__
()
def
gen_worker_desc
(
self
,
trainer_desc
):
def
_
gen_worker_desc
(
self
,
trainer_desc
):
"""
"""
Generator worker desc, which device worker is HogwildWorker.
Generator worker desc, which device worker is HogwildWorker.
...
@@ -85,13 +95,15 @@ class DownpourSGD(DeviceWorker):
...
@@ -85,13 +95,15 @@ class DownpourSGD(DeviceWorker):
"""
"""
DownpourSGD is a kind of distributed SGD algorithm.
DownpourSGD is a kind of distributed SGD algorithm.
"""
"""
def
__init__
(
self
):
def
__init__
(
self
):
"""
"""
Init.
Init.
initialize downpourSGD device worker
"""
"""
super
(
DownpourSGD
,
self
).
__init__
()
super
(
DownpourSGD
,
self
).
__init__
()
def
gen_worker_desc
(
self
,
trainer_desc
):
def
_
gen_worker_desc
(
self
,
trainer_desc
):
"""
"""
Generator worker desc, which device worker is DownpourWorker.
Generator worker desc, which device worker is DownpourWorker.
...
@@ -162,6 +174,6 @@ class DownpourSGD(DeviceWorker):
...
@@ -162,6 +174,6 @@ class DownpourSGD(DeviceWorker):
class
DeviceWorkerFactory
(
object
):
class
DeviceWorkerFactory
(
object
):
def
create_device_worker
(
self
,
worker_type
):
def
_
create_device_worker
(
self
,
worker_type
):
classname
=
worker_type
.
capitalize
()
classname
=
worker_type
.
capitalize
()
return
globals
()[
classname
]()
return
globals
()[
classname
]()
python/paddle/fluid/executor.py
浏览文件 @
d87ba58c
...
@@ -637,23 +637,23 @@ class Executor(object):
...
@@ -637,23 +637,23 @@ class Executor(object):
assert
len
(
fetch_list
)
==
len
(
fetch_info
)
assert
len
(
fetch_list
)
==
len
(
fetch_info
)
compiled
=
isinstance
(
program
,
compiler
.
CompiledProgram
)
compiled
=
isinstance
(
program
,
compiler
.
CompiledProgram
)
if
not
compiled
:
if
not
compiled
:
trainer
=
TrainerFactory
().
create_trainer
(
program
.
_fleet_opt
)
trainer
=
TrainerFactory
().
_
create_trainer
(
program
.
_fleet_opt
)
trainer
.
set_program
(
program
)
trainer
.
_
set_program
(
program
)
else
:
else
:
trainer
=
TrainerFactory
().
create_trainer
(
trainer
=
TrainerFactory
().
_
create_trainer
(
program
.
program
.
_fleet_opt
)
program
.
program
.
_fleet_opt
)
trainer
.
set_program
(
program
.
program
)
trainer
.
_
set_program
(
program
.
program
)
if
thread
<=
0
:
if
thread
<=
0
:
if
dataset
.
thread_num
<=
0
:
if
dataset
.
thread_num
<=
0
:
raise
RuntimeError
(
raise
RuntimeError
(
"You should set thread num first, either in Dataset"
"You should set thread num first, either in Dataset"
"or in Executor.train_from_dataset"
)
"or in Executor.train_from_dataset"
)
else
:
else
:
trainer
.
set_thread
(
dataset
.
thread_num
)
trainer
.
_
set_thread
(
dataset
.
thread_num
)
else
:
else
:
trainer
.
set_thread
(
thread
)
trainer
.
_
set_thread
(
thread
)
trainer
.
set_debug
(
debug
)
trainer
.
_
set_debug
(
debug
)
trainer
.
set_fetch_var_and_info
(
fetch_list
,
fetch_info
,
print_period
)
trainer
.
_
set_fetch_var_and_info
(
fetch_list
,
fetch_info
,
print_period
)
return
trainer
return
trainer
def
infer_from_dataset
(
self
,
def
infer_from_dataset
(
self
,
...
@@ -679,7 +679,7 @@ class Executor(object):
...
@@ -679,7 +679,7 @@ class Executor(object):
for each run. default is global_scope
for each run. default is global_scope
thread(int): number of thread a user wants to run in this function. The actual number
thread(int): number of thread a user wants to run in this function. The actual number
of thread will be min(Dataset.thread_num, thread)
of thread will be min(Dataset.thread_num, thread)
debug(bool): whether a user wants to run
train
_from_dataset
debug(bool): whether a user wants to run
infer
_from_dataset
fetch_list(Variable List): fetch variable list, each variable
fetch_list(Variable List): fetch variable list, each variable
will be printed during training
will be printed during training
fetch_info(String List): print information for each variable
fetch_info(String List): print information for each variable
...
@@ -711,8 +711,8 @@ class Executor(object):
...
@@ -711,8 +711,8 @@ class Executor(object):
fetch_list
=
fetch_list
,
fetch_list
=
fetch_list
,
fetch_info
=
fetch_info
,
fetch_info
=
fetch_info
,
print_period
=
print_period
)
print_period
=
print_period
)
trainer
.
gen_trainer_desc
()
trainer
.
_
gen_trainer_desc
()
trainer
.
set_infer
(
True
)
trainer
.
_
set_infer
(
True
)
dataset
.
_prepare_to_run
()
dataset
.
_prepare_to_run
()
if
debug
:
if
debug
:
self
.
_dump_debug_info
(
program
=
program
,
trainer
=
trainer
)
self
.
_dump_debug_info
(
program
=
program
,
trainer
=
trainer
)
...
@@ -784,7 +784,7 @@ class Executor(object):
...
@@ -784,7 +784,7 @@ class Executor(object):
fetch_list
=
fetch_list
,
fetch_list
=
fetch_list
,
fetch_info
=
fetch_info
,
fetch_info
=
fetch_info
,
print_period
=
print_period
)
print_period
=
print_period
)
trainer
.
gen_trainer_desc
()
trainer
.
_
gen_trainer_desc
()
dataset
.
_prepare_to_run
()
dataset
.
_prepare_to_run
()
if
debug
:
if
debug
:
self
.
_dump_debug_info
(
program
=
program
,
trainer
=
trainer
)
self
.
_dump_debug_info
(
program
=
program
,
trainer
=
trainer
)
...
...
python/paddle/fluid/trainer_desc.py
浏览文件 @
d87ba58c
...
@@ -37,32 +37,32 @@ class TrainerDesc(object):
...
@@ -37,32 +37,32 @@ class TrainerDesc(object):
self
.
program_
=
None
self
.
program_
=
None
self
.
infer_
=
False
self
.
infer_
=
False
def
set_fetch_var_and_info
(
self
,
fetch_vars
,
fetch_info
,
print_period
):
def
_
set_fetch_var_and_info
(
self
,
fetch_vars
,
fetch_info
,
print_period
):
for
i
,
v
in
enumerate
(
fetch_vars
):
for
i
,
v
in
enumerate
(
fetch_vars
):
self
.
proto_desc
.
fetch_config
.
fetch_var_names
.
extend
([
v
.
name
])
self
.
proto_desc
.
fetch_config
.
fetch_var_names
.
extend
([
v
.
name
])
self
.
proto_desc
.
fetch_config
.
fetch_var_str_format
.
extend
(
self
.
proto_desc
.
fetch_config
.
fetch_var_str_format
.
extend
(
[
fetch_info
[
i
]])
[
fetch_info
[
i
]])
self
.
proto_desc
.
fetch_config
.
print_period
=
print_period
self
.
proto_desc
.
fetch_config
.
print_period
=
print_period
def
set_debug
(
self
,
debug
):
def
_
set_debug
(
self
,
debug
):
self
.
proto_desc
.
debug
=
debug
self
.
proto_desc
.
debug
=
debug
def
set_thread
(
self
,
thread_num
):
def
_
set_thread
(
self
,
thread_num
):
self
.
proto_desc
.
thread_num
=
thread_num
self
.
proto_desc
.
thread_num
=
thread_num
def
set_device_worker
(
self
,
device_worker
):
def
_
set_device_worker
(
self
,
device_worker
):
self
.
device_worker_
=
device_worker
self
.
device_worker_
=
device_worker
def
set_infer
(
self
,
infer
):
def
_
set_infer
(
self
,
infer
):
self
.
infer_
=
infer
self
.
infer_
=
infer
def
set_fleet_desc
(
self
,
fleet_desc
):
def
_
set_fleet_desc
(
self
,
fleet_desc
):
self
.
fleet_desc_
=
fleet_desc
self
.
fleet_desc_
=
fleet_desc
def
gen_trainer_desc
(
self
):
def
_
gen_trainer_desc
(
self
):
pass
pass
def
set_program
(
self
,
program
):
def
_
set_program
(
self
,
program
):
self
.
program_
=
program
self
.
program_
=
program
def
_desc
(
self
):
def
_desc
(
self
):
...
@@ -74,11 +74,11 @@ class MultiTrainer(TrainerDesc):
...
@@ -74,11 +74,11 @@ class MultiTrainer(TrainerDesc):
super
(
MultiTrainer
,
self
).
__init__
()
super
(
MultiTrainer
,
self
).
__init__
()
pass
pass
def
set_program
(
self
,
program
):
def
_
set_program
(
self
,
program
):
super
(
MultiTrainer
,
self
).
set_program
(
program
)
super
(
MultiTrainer
,
self
).
set_program
(
program
)
self
.
program_
=
program
self
.
program_
=
program
def
gen_trainer_desc
(
self
):
def
_
gen_trainer_desc
(
self
):
super
(
MultiTrainer
,
self
).
gen_trainer_desc
()
super
(
MultiTrainer
,
self
).
gen_trainer_desc
()
self
.
proto_desc
.
class_name
=
"MultiTrainer"
self
.
proto_desc
.
class_name
=
"MultiTrainer"
self
.
device_worker_
.
set_infer
(
self
.
infer_
)
self
.
device_worker_
.
set_infer
(
self
.
infer_
)
...
@@ -90,11 +90,11 @@ class DistMultiTrainer(TrainerDesc):
...
@@ -90,11 +90,11 @@ class DistMultiTrainer(TrainerDesc):
super
(
DistMultiTrainer
,
self
).
__init__
()
super
(
DistMultiTrainer
,
self
).
__init__
()
pass
pass
def
set_program
(
self
,
program
):
def
_
set_program
(
self
,
program
):
super
(
DistMultiTrainer
,
self
).
set_program
(
program
)
super
(
DistMultiTrainer
,
self
).
set_program
(
program
)
self
.
program_
=
program
self
.
program_
=
program
def
gen_trainer_desc
(
self
):
def
_
gen_trainer_desc
(
self
):
super
(
DistMultiTrainer
,
self
).
gen_trainer_desc
()
super
(
DistMultiTrainer
,
self
).
gen_trainer_desc
()
self
.
proto_desc
.
class_name
=
"DistMultiTrainer"
self
.
proto_desc
.
class_name
=
"DistMultiTrainer"
if
self
.
program_
==
None
:
if
self
.
program_
==
None
:
...
...
python/paddle/fluid/trainer_factory.py
浏览文件 @
d87ba58c
...
@@ -22,7 +22,7 @@ class TrainerFactory(object):
...
@@ -22,7 +22,7 @@ class TrainerFactory(object):
def
__init__
(
self
):
def
__init__
(
self
):
pass
pass
def
create_trainer
(
self
,
opt_info
=
None
):
def
_
create_trainer
(
self
,
opt_info
=
None
):
trainer
=
None
trainer
=
None
device_worker
=
None
device_worker
=
None
if
opt_info
==
None
:
if
opt_info
==
None
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录