Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
5812076e
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5812076e
编写于
5月 04, 2018
作者:
H
Helin Wang
提交者:
GitHub
5月 04, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10316 from helinwang/dist
Fluid new API: dist train without modifying code
上级
4558c0ec
8ee23da8
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
52 addition
and
4 deletion
+52
-4
python/paddle/fluid/trainer.py
python/paddle/fluid/trainer.py
+52
-4
未找到文件。
python/paddle/fluid/trainer.py
浏览文件 @
5812076e
...
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
core
import
framework
import
executor
...
...
@@ -20,6 +21,7 @@ import contextlib
# optimizer is same as the parameter of Trainer.__init__. Rename it to opt_module
import
optimizer
as
opt_module
import
distribute_transpiler
__all__
=
[
'Trainer'
,
...
...
@@ -76,22 +78,61 @@ class Trainer(object):
raise
TypeError
(
"The optimizer should be an instance of Optimizer"
)
optimizer
.
minimize
(
loss
)
optimize
_ops
,
params_grads
=
optimize
r
.
minimize
(
loss
)
self
.
place
=
Trainer
.
_check_and_get_place
(
place
)
self
.
dist_transpile_if_necessary
(
optimize_ops
,
params_grads
)
# 2. move the default_main_program to self.program and run the
# default_startup program on an empty core.Scope()
# Run startup program
exe
=
executor
.
Executor
(
place
)
exe
.
run
(
self
.
startup_program
,
scope
=
self
.
scope
)
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
place
)
exe
.
run
(
self
.
startup_program
)
if
param_path
:
# load params from param_path into scope
# TODO(yuyang): This depends on parameters implementation.
pass
# TODO(helin): support distributed training
def
dist_transpile_if_necessary
(
self
,
optimize_ops
,
params_grads
):
if
"PADDLE_TRAINING_ROLE"
not
in
os
.
environ
:
return
# the port of all pservers, needed by both trainer and pserver
port
=
os
.
getenv
(
"PADDLE_PSERVER_PORT"
,
"6174"
)
# comma separated ips of all pservers, needed by trainer and
# pserver
pserver_ips
=
os
.
getenv
(
"PADDLE_PSERVER_IPS"
,
""
)
eplist
=
[]
for
ip
in
pserver_ips
.
split
(
","
):
eplist
.
append
(
':'
.
join
([
ip
,
port
]))
pserver_endpoints
=
","
.
join
(
eplist
)
# total number of workers/trainers in the job, needed by
# trainer and pserver
trainers
=
int
(
os
.
getenv
(
"PADDLE_TRAINERS"
))
# the IP of the local machine, needed by pserver only
current_endpoint
=
os
.
getenv
(
"PADDLE_CURRENT_IP"
,
""
)
+
":"
+
port
# the unique trainer id, starting from 0, needed by trainer
# only
trainer_id
=
int
(
os
.
getenv
(
"PADDLE_TRAINER_ID"
,
"0"
))
# the role, should be either PSERVER or TRAINER
training_role
=
os
.
getenv
(
"PADDLE_TRAINING_ROLE"
)
with
self
.
_prog_and_scope_guard
():
t
=
distribute_transpiler
.
DistributeTranspiler
()
t
.
transpile
(
trainer_id
,
pservers
=
pserver_endpoints
,
trainers
=
trainers
)
if
training_role
==
"PSERVER"
:
self
.
train_program
=
t
.
get_pserver_program
(
current_endpoint
)
self
.
startup_program
=
t
.
get_startup_program
(
current_endpoint
,
self
.
train_program
)
elif
training_role
==
"TRAINER"
:
self
.
train_program
=
t
.
get_trainer_program
()
else
:
raise
ValueError
(
'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
)
def
train
(
self
,
num_epochs
,
...
...
@@ -117,6 +158,13 @@ class Trainer(object):
raise
NotImplementedError
(
"Parallel Executor version of trainer is not implemented"
)
training_role
=
os
.
getenv
(
"PADDLE_TRAINING_ROLE"
,
""
)
if
training_role
==
"PSERVER"
:
with
self
.
_prog_and_scope_guard
():
exe
=
executor
.
Executor
(
self
.
place
)
exe
.
run
()
return
self
.
_train_by_executor
(
num_epochs
,
event_handler
,
reader
,
feed_order
)
def
test
(
self
,
reader
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录