Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
a38b98cb
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a38b98cb
编写于
3月 26, 2019
作者:
X
xjqbest
提交者:
dongdaxiang
3月 29, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix code style & runtime error
test=develop
上级
8e14d8f9
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
23 addition
and
25 deletion
+23
-25
python/paddle/fluid/dataset.py
python/paddle/fluid/dataset.py
+3
-3
python/paddle/fluid/incubate/fleet/base/role_maker.py
python/paddle/fluid/incubate/fleet/base/role_maker.py
+12
-12
python/paddle/fluid/incubate/fleet/parameter_server/__init__.py
.../paddle/fluid/incubate/fleet/parameter_server/__init__.py
+1
-1
python/paddle/fluid/tests/unittests/test_dataset.py
python/paddle/fluid/tests/unittests/test_dataset.py
+7
-9
未找到文件。
python/paddle/fluid/dataset.py
浏览文件 @
a38b98cb
...
...
@@ -235,15 +235,15 @@ class InMemoryDataset(DatasetBase):
"""
trainer_num
=
1
if
fleet
is
not
None
:
fleet
.
fleet_instance
.
role_maker_
.
barrier_worker
()
fleet
.
fleet_instance
.
role_maker_
.
_
barrier_worker
()
trainer_num
=
fleet
.
worker_num
()
self
.
dataset
.
register_client2client_msg_handler
()
self
.
dataset
.
set_trainer_num
(
trainer_num
)
if
fleet
is
not
None
:
fleet
.
fleet_instance
.
role_maker_
.
barrier_worker
()
fleet
.
fleet_instance
.
role_maker_
.
_
barrier_worker
()
self
.
dataset
.
global_shuffle
()
if
fleet
is
not
None
:
fleet
.
fleet_instance
.
role_maker_
.
barrier_worker
()
fleet
.
fleet_instance
.
role_maker_
.
_
barrier_worker
()
class
QueueDataset
(
DatasetBase
):
...
...
python/paddle/fluid/incubate/fleet/base/role_maker.py
浏览文件 @
a38b98cb
...
...
@@ -98,7 +98,7 @@ class MPIRoleMaker(RoleMakerBase):
"""
all_gather(obj) will call MPI's allgather function
"""
self
.
barrier_all
()
self
.
_
barrier_all
()
return
self
.
comm_
.
allgather
(
obj
)
def
_barrier_all
(
self
):
...
...
@@ -112,7 +112,7 @@ class MPIRoleMaker(RoleMakerBase):
collect current distributed job's ip list
"""
if
self
.
ips_
==
None
:
self
.
ips_
=
self
.
comm_
.
allgather
(
self
.
get_local_ip
())
self
.
ips_
=
self
.
comm_
.
allgather
(
self
.
_
get_local_ip
())
return
self
.
ips_
def
_finalize
(
self
):
...
...
@@ -146,7 +146,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
return whether current process is the first worker assigned by role maker
"""
if
self
.
_check_role_generation
():
return
self
.
is_worker
()
and
0
==
self
.
worker_index
()
return
self
.
_is_worker
()
and
0
==
self
.
_
worker_index
()
return
False
def
_is_worker
(
self
):
...
...
@@ -170,8 +170,8 @@ class MPISymetricRoleMaker(MPIRoleMaker):
return the current number of worker
"""
if
self
.
_check_role_generation
():
if
self
.
is_worker
():
return
self
.
get_size
()
/
2
if
self
.
_
is_worker
():
return
self
.
_
get_size
()
/
2
return
0
def
_server_num
(
self
):
...
...
@@ -179,8 +179,8 @@ class MPISymetricRoleMaker(MPIRoleMaker):
return the current number of server
"""
if
self
.
_check_role_generation
():
if
self
.
is_server
():
return
self
.
get_size
()
/
2
if
self
.
_
is_server
():
return
self
.
_
get_size
()
/
2
return
0
def
_worker_index
(
self
):
...
...
@@ -204,7 +204,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
barrier all workers in current distributed job
"""
if
self
.
_check_role_generation
():
if
self
.
is_worker
():
if
self
.
_
is_worker
():
self
.
node_type_comm_
.
barrier
()
def
_barrier_server
(
self
):
...
...
@@ -212,7 +212,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
barrier all servers in current distributed job
"""
if
self
.
_check_role_generation
():
if
self
.
is_server
():
if
self
.
_
is_server
():
self
.
node_type_comm_
.
barrier
()
def
_generate_role
(
self
):
...
...
@@ -221,10 +221,10 @@ class MPISymetricRoleMaker(MPIRoleMaker):
"""
if
not
self
.
role_is_generated_
:
# TODO(guru4elephant): only allow to be called once
self
.
trainer_endpoints_
=
self
.
get_ips
()
self
.
pserver_endpoints_
=
self
.
get_ips
()
self
.
trainer_endpoints_
=
self
.
_
get_ips
()
self
.
pserver_endpoints_
=
self
.
_
get_ips
()
if
0
==
self
.
get_rank
()
%
self
.
proc_per_node_
%
2
:
if
0
==
self
.
_
get_rank
()
%
self
.
proc_per_node_
%
2
:
self
.
node_type_
=
0
else
:
self
.
node_type_
=
1
...
...
python/paddle/fluid/incubate/fleet/parameter_server/__init__.py
浏览文件 @
a38b98cb
...
...
@@ -88,7 +88,7 @@ class Fleet(object):
stop(): will be called after a user finishes his/her training task. Fleet instance will be
destroyed when stop() is called.
"""
self
.
role_maker_
.
barrier_worker
()
self
.
role_maker_
.
_
barrier_worker
()
if
self
.
role_maker_
.
_is_first_worker
():
self
.
_fleet_ptr
.
stop_server
()
self
.
role_maker_
.
_barrier_worker
()
...
...
python/paddle/fluid/tests/unittests/test_dataset.py
浏览文件 @
a38b98cb
"""
dataset testcases
"""
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -21,13 +25,9 @@ import unittest
class
TestDataset
(
unittest
.
TestCase
):
"""
TestCases for Dataset.
"""
""" TestCases for Dataset. """
def
test_dataset_create
(
self
):
"""
Testcase for dataset create
"""
""" Testcase for dataset create """
try
:
dataset
=
fluid
.
DatasetFactory
().
create_dataset
(
"InMemoryDataset"
)
except
:
...
...
@@ -45,9 +45,7 @@ class TestDataset(unittest.TestCase):
self
.
assertTrue
(
True
)
def
test_dataset_config
(
self
):
"""
Testcase for dataset configuration
"""
""" Testcase for dataset configuration """
dataset
=
fluid
.
core
.
Dataset
(
"MultiSlotDataset"
)
dataset
.
set_thread_num
(
12
)
dataset
.
set_filelist
([
"a.txt"
,
"b.txt"
,
"c.txt"
])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录