Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
a25a716e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a25a716e
编写于
9月 06, 2019
作者:
1
123malin
提交者:
GitHub
9月 06, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize fleet API: add input check for some interfaces (#18971)
* fleet api add input check, test=develop
上级
ed8f44ea
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
351 addition
and
29 deletion
+351
-29
paddle/fluid/API.spec
paddle/fluid/API.spec
+2
-2
python/paddle/fluid/incubate/fleet/base/fleet_base.py
python/paddle/fluid/incubate/fleet/base/fleet_base.py
+5
-2
python/paddle/fluid/incubate/fleet/base/role_maker.py
python/paddle/fluid/incubate/fleet/base/role_maker.py
+49
-17
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
.../fleet/parameter_server/distribute_transpiler/__init__.py
+54
-6
python/paddle/fluid/tests/unittests/test_fleet_api_input.py
python/paddle/fluid/tests/unittests/test_fleet_api_input.py
+208
-0
python/paddle/fluid/transpiler/distribute_transpiler.py
python/paddle/fluid/transpiler/distribute_transpiler.py
+33
-2
未找到文件。
paddle/fluid/API.spec
浏览文件 @
a25a716e
...
...
@@ -34,7 +34,7 @@ paddle.fluid.DistributeTranspiler.transpile (ArgSpec(args=['self', 'trainer_id',
paddle.fluid.memory_optimize (ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level', 'skip_grads'], varargs=None, keywords=None, defaults=(None, False, 0, True)), ('document', '2348247f684bfd5bb9466470f35be064'))
paddle.fluid.release_memory (ArgSpec(args=['input_program', 'skip_opt_set'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd38c5b8b2b2e0bb19bcf1b581a80a7e4'))
paddle.fluid.DistributeTranspilerConfig ('paddle.fluid.transpiler.distribute_transpiler.DistributeTranspilerConfig', ('document', '550b8c767a8ae1a2eb74b18924ddc975'))
paddle.fluid.DistributeTranspilerConfig.__init__
paddle.fluid.DistributeTranspilerConfig.__init__
(ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.ParallelExecutor ('paddle.fluid.parallel_executor.ParallelExecutor', ('document', '2b4d2e859f2e0c6161f4fed995f7956d'))
paddle.fluid.ParallelExecutor.__init__ (ArgSpec(args=['self', 'use_cuda', 'loss_name', 'main_program', 'share_vars_from', 'exec_strategy', 'build_strategy', 'num_trainers', 'trainer_id', 'scope'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 1, 0, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.ParallelExecutor.drop_local_exe_scopes (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '77c739744ea5708b80fb1b37cc89db40'))
...
...
@@ -878,7 +878,7 @@ paddle.fluid.transpiler.RoundRobin.__init__ (ArgSpec(args=['self', 'pserver_endp
paddle.fluid.transpiler.RoundRobin.dispatch (ArgSpec(args=['self', 'varlist'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.transpiler.RoundRobin.reset (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.transpiler.DistributeTranspilerConfig ('paddle.fluid.transpiler.distribute_transpiler.DistributeTranspilerConfig', ('document', '550b8c767a8ae1a2eb74b18924ddc975'))
paddle.fluid.transpiler.DistributeTranspilerConfig.__init__
paddle.fluid.transpiler.DistributeTranspilerConfig.__init__
(ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.nets.simple_img_conv_pool (ArgSpec(args=['input', 'num_filters', 'filter_size', 'pool_size', 'pool_stride', 'pool_padding', 'pool_type', 'global_pooling', 'conv_stride', 'conv_padding', 'conv_dilation', 'conv_groups', 'param_attr', 'bias_attr', 'act', 'use_cudnn'], varargs=None, keywords=None, defaults=(0, 'max', False, 1, 0, 1, 1, None, None, None, True)), ('document', '13f01ff80e8dfbd3427d90cf49bc62eb'))
paddle.fluid.nets.sequence_conv_pool (ArgSpec(args=['input', 'num_filters', 'filter_size', 'param_attr', 'act', 'pool_type', 'bias_attr'], varargs=None, keywords=None, defaults=(None, 'sigmoid', 'max', None)), ('document', 'd6a1e527b53f5cc15594fee307dfc5cf'))
paddle.fluid.nets.glu (ArgSpec(args=['input', 'dim'], varargs=None, keywords=None, defaults=(-1,)), ('document', 'b87bacfc70dd3477ed25ef14aa01389a'))
...
...
python/paddle/fluid/incubate/fleet/base/fleet_base.py
浏览文件 @
a25a716e
...
...
@@ -159,6 +159,9 @@ class Fleet(object):
Returns:
list: files belongs to this worker.
"""
if
not
isinstance
(
files
,
list
):
raise
TypeError
(
"files should be a list of file need to be read."
)
trainer_id
=
self
.
worker_index
()
trainers
=
self
.
worker_num
()
...
...
@@ -192,7 +195,7 @@ class Fleet(object):
self
.
_executor
=
Executor
(
fluid
.
CPUPlace
())
if
role_maker
and
not
isinstance
(
role_maker
,
RoleMakerBase
):
raise
Valu
eError
(
"role_maker must be an instance of RoleMakerBase"
)
raise
Typ
eError
(
"role_maker must be an instance of RoleMakerBase"
)
self
.
_role_maker
=
role_maker
self
.
_role_maker
.
generate_role
()
...
...
@@ -255,7 +258,7 @@ class DistributedOptimizer(object):
def
__init__
(
self
,
optimizer
,
strategy
=
None
):
if
not
isinstance
(
optimizer
,
SGD
.
__bases__
):
raise
Valu
eError
(
"optimizer must be an instance of Optimizer"
)
raise
Typ
eError
(
"optimizer must be an instance of Optimizer"
)
self
.
_optimizer
=
optimizer
self
.
_strategy
=
strategy
...
...
python/paddle/fluid/incubate/fleet/base/role_maker.py
浏览文件 @
a25a716e
...
...
@@ -435,30 +435,46 @@ class UserDefinedRoleMaker(RoleMakerBase):
"""
super
(
UserDefinedRoleMaker
,
self
).
__init__
()
if
not
isinstance
(
current_id
,
int
):
raise
TypeError
(
"current_id must be as int"
)
if
not
isinstance
(
server_endpoints
,
list
):
raise
TypeError
(
"server_endpoints must be as string list"
)
elif
len
(
server_endpoints
)
<=
0
:
raise
ValueError
(
"the length of server_endpoints list must be greater than 0"
)
elif
len
(
server_endpoints
)
!=
len
(
set
(
server_endpoints
)):
raise
ValueError
(
"server_endpoints can't have duplicate elements"
)
else
:
if
current_id
<
0
:
raise
ValueError
(
"current_id must be gather or equal 0"
)
self
.
_current_id
=
current_id
for
server_endpoint
in
server_endpoints
:
if
not
isinstance
(
server_endpoint
,
str
):
raise
TypeError
(
"every element in server_endpoints list must be as string"
)
self
.
_server_endpoints
=
server_endpoints
if
role
!=
Role
.
WORKER
and
role
!=
Role
.
SERVER
:
raise
TypeError
(
"role must be as Role"
)
else
:
self
.
_role
=
role
if
not
isinstance
(
current_id
,
int
):
raise
TypeError
(
"current_id must be as int"
)
else
:
if
current_id
<
0
:
raise
ValueError
(
"current_id must be greater than or equal to 0"
)
elif
self
.
_role
==
Role
.
SERVER
and
current_id
>=
len
(
server_endpoints
):
raise
ValueError
(
"if role is Role.SERVER, current_id must be less than or equal to len(server_endpoints) - 1"
)
self
.
_current_id
=
current_id
if
not
isinstance
(
worker_num
,
int
):
raise
TypeError
(
"worker_num must be as int"
)
else
:
if
worker_num
<
0
:
raise
ValueError
(
"worker_num must be g
ather or equal
0"
)
if
worker_num
<
=
0
:
raise
ValueError
(
"worker_num must be g
reater than
0"
)
self
.
_worker_num
=
worker_num
if
not
isinstance
(
server_endpoints
,
list
):
raise
TypeError
(
"server_endpoints must be as string list"
)
else
:
self
.
_server_endpoints
=
server_endpoints
def
generate_role
(
self
):
self
.
_role_is_generated
=
True
...
...
@@ -489,17 +505,33 @@ class UserDefinedCollectiveRoleMaker(RoleMakerBase):
"""
super
(
UserDefinedCollectiveRoleMaker
,
self
).
__init__
()
if
not
isinstance
(
worker_endpoints
,
list
):
raise
TypeError
(
"worker_endpoints must be as string list"
)
elif
len
(
worker_endpoints
)
<=
0
:
raise
ValueError
(
"the length of worker_endpoints list must be greater than 0"
)
elif
len
(
worker_endpoints
)
!=
len
(
set
(
worker_endpoints
)):
raise
ValueError
(
"worker_endpoints can't have duplicate elements"
)
else
:
for
worker_endpoint
in
worker_endpoints
:
if
not
isinstance
(
worker_endpoint
,
str
):
raise
TypeError
(
"every element in worker_endpoints list must be as string"
)
self
.
_worker_endpoints
=
worker_endpoints
if
not
isinstance
(
current_id
,
int
):
raise
TypeError
(
"current_id must be as int"
)
else
:
if
current_id
<
0
:
raise
ValueError
(
"current_id must be greater or equal 0"
)
raise
ValueError
(
"current_id must be greater than or equal to 0"
)
elif
current_id
>=
len
(
worker_endpoints
):
raise
ValueError
(
"current_id must be less than or equal to len(worker_endpoints) - 1"
)
self
.
_current_id
=
current_id
if
not
isinstance
(
worker_endpoints
,
list
):
raise
TypeError
(
"worker_endpoints must be as string list"
)
else
:
self
.
_worker_endpoints
=
worker_endpoints
self
.
_worker_num
=
len
(
self
.
_worker_endpoints
)
def
generate_role
(
self
):
...
...
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
浏览文件 @
a25a716e
...
...
@@ -19,6 +19,9 @@ from paddle.fluid.communicator import Communicator
from
paddle.fluid.framework
import
default_main_program
from
paddle.fluid.framework
import
default_startup_program
from
paddle.fluid.framework
import
Program
from
paddle.fluid.compiler
import
CompiledProgram
from
paddle.fluid.executor
import
Executor
from
paddle.fluid.parallel_executor
import
ParallelExecutor
from
paddle.fluid.optimizer
import
Optimizer
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspiler
as
OriginTranspiler
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
...
...
@@ -134,7 +137,7 @@ class DistributedTranspiler(Fleet):
Args:
optimizer(Optimizer): The executor to run for init server.
strategy(
dict
): Extra properties for distributed optimizer.
strategy(
DistributeTranspilerConfig
): Extra properties for distributed optimizer.
Returns:
TranspilerOptimizer: subclass of DistributedOptimizer.
...
...
@@ -156,7 +159,21 @@ class DistributedTranspiler(Fleet):
Prune the given `main_program` to build a new program especially for inference,
and then save it and all related parameters to given `dirname` by the `executor`.
"""
if
isinstance
(
executor
,
ParallelExecutor
):
raise
TypeError
(
"in fleet.save_inference_model() function, executor must be as Executor type, ParallelExecutor is not allowed"
)
if
not
isinstance
(
executor
,
Executor
):
raise
TypeError
(
"in fleet.save_inference_model() function, executor must be as Executor type"
)
if
main_program
is
not
None
:
if
isinstance
(
main_program
,
CompiledProgram
):
raise
TypeError
(
"in fleet.save_inference_model() function, main_program must be as Program type, CompiledProgram is not allowed"
)
io
.
save_inference_model
(
dirname
,
feeded_var_names
,
target_vars
,
executor
,
main_program
,
None
,
None
,
export_for_deployment
)
...
...
@@ -186,10 +203,24 @@ class DistributedTranspiler(Fleet):
files, set `filename` None; if you would like to save all variables in a
single file, use `filename` to specify the file name.
"""
if
isinstance
(
executor
,
ParallelExecutor
):
raise
TypeError
(
"in fleet.save_persistables() function, executor must be as Executor type, ParallelExecutor is not allowed"
)
if
not
isinstance
(
executor
,
Executor
):
raise
TypeError
(
"in fleet.save_persistables() function, executor must be as Executor type"
)
if
main_program
is
None
:
main_program
=
self
.
main_program
if
isinstance
(
main_program
,
CompiledProgram
):
raise
TypeError
(
"in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
)
if
not
main_program
.
_is_distributed
:
raise
ValueError
(
"main_program is for local, may not use fleet.save_persistables"
)
...
...
@@ -198,7 +229,7 @@ class DistributedTranspiler(Fleet):
def
_transpile
(
self
,
config
):
if
not
isinstance
(
config
,
DistributeTranspilerConfig
):
raise
Valu
eError
(
raise
Typ
eError
(
"config must be an instance of DistributeTranspilerConfig"
)
if
not
config
.
sync_mode
:
...
...
@@ -260,7 +291,7 @@ class TranspilerOptimizer(DistributedOptimizer):
if
strategy
:
if
not
isinstance
(
strategy
,
DistributeTranspilerConfig
):
raise
Valu
eError
(
raise
Typ
eError
(
"In {} mode, strategy must be an instance of DistributeTranspilerConfig"
.
format
(
fleet
.
_mode
))
else
:
...
...
@@ -321,17 +352,34 @@ class TranspilerOptimizer(DistributedOptimizer):
def
minimize
(
self
,
loss
,
scope
=
None
,
scope
s
=
None
,
startup_program
=
None
,
parameter_list
=
None
,
no_grad_set
=
None
):
"""
Add operations to minimize `loss` by updating `parameter_list`.
This method combines interface `backward()` and
`apply_gradients()` into one.
Args:
loss (Variable): loss variable to run optimizations.
scopes (None): TranspilerOptimizer doesn't need scope parameter.
startup_program (Program): startup_program for initializing parameters
in `parameter_list`.
parameter_list (list): list of Variables to update.
no_grad_set (set|None): set of Variables should be ignored.
Returns:
tuple: (optimize_ops, params_grads) which are, list of operators appended;
and list of (param, grad) Variables pair for optimization.
"""
if
isinstance
(
loss
,
list
):
raise
Valu
eError
(
raise
Typ
eError
(
"DistributedTranspiler's minimize can not accept loss with list"
)
if
isinstance
(
startup_program
,
list
):
raise
Valu
eError
(
raise
Typ
eError
(
"DistributedTranspiler's minimize can not accept program with list"
)
...
...
python/paddle/fluid/tests/unittests/test_fleet_api_input.py
0 → 100644
浏览文件 @
a25a716e
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
paddle.fluid
as
fluid
from
paddle.fluid.transpiler.distribute_transpiler
import
DistributeTranspilerConfig
from
paddle.fluid.incubate.fleet.base.role_maker
import
UserDefinedRoleMaker
from
paddle.fluid.incubate.fleet.base.role_maker
import
UserDefinedCollectiveRoleMaker
from
paddle.fluid.incubate.fleet.base.role_maker
import
Role
from
paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler
import
fleet
from
paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler
import
TranspilerOptimizer
class
DistributeTranspilerConfigTest
(
unittest
.
TestCase
):
def
set_runtime_split_send_recv
(
self
,
config
,
value
):
config
.
runtime_split_send_recv
=
value
def
set_sync_mode
(
self
,
config
,
value
):
config
.
sync_mode
=
value
def
testConfig
(
self
):
config
=
DistributeTranspilerConfig
()
self
.
assertRaises
(
Exception
,
self
.
set_sync_mode
,
config
,
None
)
self
.
assertRaises
(
Exception
,
self
.
set_runtime_split_send_recv
,
config
,
None
)
self
.
assertRaises
(
Exception
,
self
.
set_runtime_split_send_recv
,
config
,
True
)
self
.
set_sync_mode
(
config
,
False
)
self
.
assertFalse
(
config
.
sync_mode
)
self
.
set_runtime_split_send_recv
(
config
,
True
)
self
.
assertRaises
(
Exception
,
self
.
set_sync_mode
,
config
,
True
)
class
FleetTest
(
unittest
.
TestCase
):
def
testInvalidInputs
(
self
):
self
.
assertRaises
(
Exception
,
fleet
.
split_files
,
"files"
)
self
.
assertRaises
(
Exception
,
fleet
.
init
,
"pserver"
)
data
=
fluid
.
layers
.
data
(
name
=
'X'
,
shape
=
[
1
],
dtype
=
'float32'
)
hidden
=
fluid
.
layers
.
fc
(
input
=
data
,
size
=
10
)
loss
=
fluid
.
layers
.
mean
(
hidden
)
adam
=
fluid
.
optimizer
.
Adam
()
adam
.
minimize
(
loss
)
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
pe
=
fluid
.
ParallelExecutor
(
use_cuda
=
False
,
loss_name
=
loss
.
name
)
self
.
assertRaises
(
Exception
,
fleet
.
save_inference_model
,
dirname
=
'/tmp/'
,
feeded_var_names
=
[
'X'
],
target_vars
=
[
loss
],
executor
=
pe
)
self
.
assertRaises
(
Exception
,
fleet
.
save_inference_model
,
dirname
=
'/tmp/'
,
feeded_var_names
=
[
'X'
],
target_vars
=
[
loss
],
executor
=
"executor"
)
compiled_prog
=
fluid
.
compiler
.
CompiledProgram
(
fluid
.
default_main_program
())
self
.
assertRaises
(
Exception
,
fleet
.
save_inference_model
,
dirname
=
'/tmp/'
,
feeded_var_names
=
[
'X'
],
target_vars
=
[
loss
],
executor
=
exe
,
main_program
=
compiled_prog
)
self
.
assertRaises
(
Exception
,
fleet
.
save_persistables
,
executor
=
pe
,
dirname
=
'/tmp/'
)
self
.
assertRaises
(
Exception
,
fleet
.
save_persistables
,
executor
=
"executor"
,
dirname
=
'/tmp/'
)
self
.
assertRaises
(
Exception
,
fleet
.
save_persistables
,
executor
=
exe
,
dirname
=
'/tmp/'
,
main_program
=
compiled_prog
)
self
.
assertRaises
(
Exception
,
fleet
.
_transpile
,
"config"
)
class
TranspilerOptimizerTest
(
unittest
.
TestCase
):
def
testInvalidInputs
(
self
):
self
.
assertRaises
(
Exception
,
TranspilerOptimizer
,
"Adam"
,
None
)
self
.
assertRaises
(
Exception
,
TranspilerOptimizer
,
fluid
.
optimizer
.
Adam
(
0.001
),
"strategy"
)
transpiler
=
TranspilerOptimizer
(
fluid
.
optimizer
.
Adam
(
0.001
))
self
.
assertRaises
(
Exception
,
transpiler
.
minimize
,
loss
=
[])
data
=
fluid
.
layers
.
data
(
name
=
'X'
,
shape
=
[
1
],
dtype
=
'float32'
)
hidden
=
fluid
.
layers
.
fc
(
input
=
data
,
size
=
10
)
loss
=
fluid
.
layers
.
mean
(
hidden
)
self
.
assertRaises
(
Exception
,
transpiler
.
minimize
,
loss
=
loss
.
name
,
startup_program
=
[])
class
UserDefinedRoleMakerTest
(
unittest
.
TestCase
):
def
createRoleMaker
(
self
,
current_id
=
0
,
role
=
Role
.
WORKER
,
worker_num
=
1
,
server_endpoints
=
[
"127.0.0.1:8080"
]):
role
=
UserDefinedRoleMaker
(
current_id
,
role
,
worker_num
,
server_endpoints
)
def
testRoleMaker
(
self
):
self
.
createRoleMaker
()
## test all invalid server_endpoints
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
server_endpoints
=
None
)
# server_endpoints must be as list
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
server_endpoints
=
[])
# server_endpoints can't be empty
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
server_endpoints
=
[
3
,
[]
])
# element in server_endpoints must be as string
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
server_endpoints
=
[
"127.0.0.1:8080"
,
"127.0.0.1:8080"
]
)
# element in server_endpoints can't be duplicate
## test all invalid current_id
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
current_id
=
"0"
)
# current_id must be as int
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
current_id
=-
1
)
# current_id must be greater than or equal to 0
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
current_id
=
1
,
role
=
Role
.
SERVER
,
server_endpoints
=
[
"127.0.0.1:8080"
]
)
# if role is server, current_id must be less than len(server_endpoints)
## test all invalid worker_num
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
worker_num
=
"1"
)
# worker_num must be as int
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
worker_num
=
0
)
# worker_num must be greater than 0
## test all invalid role
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
role
=
3
)
# role must be as Role(Role.WORKER=1, Role.SERVER=2)
class
UserDefinedCollectiveRoleMakerTest
(
unittest
.
TestCase
):
def
createRoleMaker
(
self
,
current_id
=
0
,
worker_endpoints
=
[
"127.0.0.1:8080"
]):
role
=
UserDefinedCollectiveRoleMaker
(
current_id
,
worker_endpoints
)
def
testRoleMaker
(
self
):
self
.
createRoleMaker
()
## test all invalid worker_endpoints
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
worker_endpoints
=
None
)
# worker_endpoints must be as list
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
worker_endpoints
=
[])
# worker_endpoints can't be empty
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
worker_endpoints
=
[
3
,
[]])
# element worker_endpoints must be as string
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
worker_endpoints
=
[
"127.0.0.1:8080"
,
"127.0.0.1:8080"
]
)
# element in worker_endpoints can't be duplicate
## test all invalid current_id
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
current_id
=
"0"
)
# current_id must be as int
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
current_id
=-
1
)
# current_id must be greater than or equal to 0
self
.
assertRaises
(
Exception
,
self
.
createRoleMaker
,
current_id
=
1
,
worker_endpoints
=
[
"127.0.0.1:8080"
]
)
# current_id must be less than len(worker_endpoints)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/transpiler/distribute_transpiler.py
浏览文件 @
a25a716e
...
...
@@ -163,8 +163,8 @@ class DistributeTranspilerConfig(object):
print_log
=
False
wait_port
=
True
# split the send recv var in runtime
runtime_split_send_recv
=
False
sync_mode
=
True
_
runtime_split_send_recv
=
False
_
sync_mode
=
True
nccl_comm_num
=
1
#The picture here illustrates the principle:
...
...
@@ -177,6 +177,37 @@ class DistributeTranspilerConfig(object):
# supported modes: grad_allreduce, local_sgd
collective_mode
=
None
def
__init__
(
self
):
pass
@
property
def
runtime_split_send_recv
(
self
):
return
self
.
_runtime_split_send_recv
@
runtime_split_send_recv
.
setter
def
runtime_split_send_recv
(
self
,
value
):
if
value
is
None
:
raise
ValueError
(
"runtime_split_send_recv can't be None"
)
if
value
and
self
.
_sync_mode
:
raise
ValueError
(
"if you want to set runtime_split_send_recv to be true, make ensure config.sync_mode is false at first"
)
self
.
_runtime_split_send_recv
=
value
@
property
def
sync_mode
(
self
):
return
self
.
_sync_mode
@
sync_mode
.
setter
def
sync_mode
(
self
,
value
):
if
value
is
None
:
raise
ValueError
(
"sync_mode can't be None"
)
if
value
and
self
.
_runtime_split_send_recv
:
raise
ValueError
(
"if you want to set sync_mode to be true, make ensure config.runtime_split_send_recv is false at first"
)
self
.
_sync_mode
=
value
class
DistributeTranspiler
(
object
):
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录