Commit a25a716e
Authored Sep 06, 2019 by 123malin; committed via GitHub on Sep 06, 2019
Optimize fleet API: add input check for some interfaces (#18971)
* fleet api add input check, test=develop
Parent: ed8f44ea
Showing 6 changed files with 351 additions and 29 deletions (+351, -29).
paddle/fluid/API.spec (+2, -2)
python/paddle/fluid/incubate/fleet/base/fleet_base.py (+5, -2)
python/paddle/fluid/incubate/fleet/base/role_maker.py (+49, -17)
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py (+54, -6)
python/paddle/fluid/tests/unittests/test_fleet_api_input.py (+208, -0)
python/paddle/fluid/transpiler/distribute_transpiler.py (+33, -2)
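A rough usage sketch of the behavior this commit adds (not part of the diff; it mirrors the new unit test test_fleet_api_input.py shown further below): wrongly typed inputs to the fleet API now fail fast with an explicit TypeError or ValueError.

# Sketch only, based on the new check in Fleet.split_files below.
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet

try:
    fleet.split_files("files")  # must be a list of file paths, not a single string
except TypeError as e:
    print("rejected early:", e)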
paddle/fluid/API.spec
@@ -34,7 +34,7 @@ paddle.fluid.DistributeTranspiler.transpile (ArgSpec(args=['self', 'trainer_id',
paddle.fluid.memory_optimize (ArgSpec(args=['input_program', 'skip_opt_set', 'print_log', 'level', 'skip_grads'], varargs=None, keywords=None, defaults=(None, False, 0, True)), ('document', '2348247f684bfd5bb9466470f35be064'))
paddle.fluid.release_memory (ArgSpec(args=['input_program', 'skip_opt_set'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd38c5b8b2b2e0bb19bcf1b581a80a7e4'))
paddle.fluid.DistributeTranspilerConfig ('paddle.fluid.transpiler.distribute_transpiler.DistributeTranspilerConfig', ('document', '550b8c767a8ae1a2eb74b18924ddc975'))
paddle.fluid.DistributeTranspilerConfig.__init__ (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.ParallelExecutor ('paddle.fluid.parallel_executor.ParallelExecutor', ('document', '2b4d2e859f2e0c6161f4fed995f7956d'))
paddle.fluid.ParallelExecutor.__init__ (ArgSpec(args=['self', 'use_cuda', 'loss_name', 'main_program', 'share_vars_from', 'exec_strategy', 'build_strategy', 'num_trainers', 'trainer_id', 'scope'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 1, 0, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.ParallelExecutor.drop_local_exe_scopes (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '77c739744ea5708b80fb1b37cc89db40'))

@@ -878,7 +878,7 @@ paddle.fluid.transpiler.RoundRobin.__init__ (ArgSpec(args=['self', 'pserver_endp
paddle.fluid.transpiler.RoundRobin.dispatch (ArgSpec(args=['self', 'varlist'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.transpiler.RoundRobin.reset (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.transpiler.DistributeTranspilerConfig ('paddle.fluid.transpiler.distribute_transpiler.DistributeTranspilerConfig', ('document', '550b8c767a8ae1a2eb74b18924ddc975'))
paddle.fluid.transpiler.DistributeTranspilerConfig.__init__ (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.nets.simple_img_conv_pool (ArgSpec(args=['input', 'num_filters', 'filter_size', 'pool_size', 'pool_stride', 'pool_padding', 'pool_type', 'global_pooling', 'conv_stride', 'conv_padding', 'conv_dilation', 'conv_groups', 'param_attr', 'bias_attr', 'act', 'use_cudnn'], varargs=None, keywords=None, defaults=(0, 'max', False, 1, 0, 1, 1, None, None, None, True)), ('document', '13f01ff80e8dfbd3427d90cf49bc62eb'))
paddle.fluid.nets.sequence_conv_pool (ArgSpec(args=['input', 'num_filters', 'filter_size', 'param_attr', 'act', 'pool_type', 'bias_attr'], varargs=None, keywords=None, defaults=(None, 'sigmoid', 'max', None)), ('document', 'd6a1e527b53f5cc15594fee307dfc5cf'))
paddle.fluid.nets.glu (ArgSpec(args=['input', 'dim'], varargs=None, keywords=None, defaults=(-1,)), ('document', 'b87bacfc70dd3477ed25ef14aa01389a'))
python/paddle/fluid/incubate/fleet/base/fleet_base.py
@@ -159,6 +159,9 @@ class Fleet(object):
        Returns:
            list: files belongs to this worker.
        """
+        if not isinstance(files, list):
+            raise TypeError("files should be a list of file need to be read.")
+
        trainer_id = self.worker_index()
        trainers = self.worker_num()
@@ -192,7 +195,7 @@ class Fleet(object):
        self._executor = Executor(fluid.CPUPlace())

        if role_maker and not isinstance(role_maker, RoleMakerBase):
-            raise ValueError("role_maker must be an instance of RoleMakerBase")
+            raise TypeError("role_maker must be an instance of RoleMakerBase")

        self._role_maker = role_maker
        self._role_maker.generate_role()
@@ -255,7 +258,7 @@ class DistributedOptimizer(object):
    def __init__(self, optimizer, strategy=None):
        if not isinstance(optimizer, SGD.__bases__):
-            raise ValueError("optimizer must be an instance of Optimizer")
+            raise TypeError("optimizer must be an instance of Optimizer")

        self._optimizer = optimizer
        self._strategy = strategy
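A minimal sketch of the stricter checks above (not part of the diff; based on these hunks and the new test file): fleet.init() now rejects a role_maker that is not a RoleMakerBase, and DistributedOptimizer rejects anything that is not a fluid Optimizer.

# Sketch only; mirrors FleetTest and TranspilerOptimizerTest in the new test file.
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import TranspilerOptimizer

try:
    fleet.init("pserver")  # a plain string is not a RoleMakerBase
except TypeError as e:
    print(e)  # "role_maker must be an instance of RoleMakerBase"

try:
    TranspilerOptimizer("Adam")  # a plain string is not an Optimizer
except TypeError as e:
    print(e)  # "optimizer must be an instance of Optimizer"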
python/paddle/fluid/incubate/fleet/base/role_maker.py
@@ -435,30 +435,46 @@ class UserDefinedRoleMaker(RoleMakerBase):
        """
        super(UserDefinedRoleMaker, self).__init__()

-        if not isinstance(current_id, int):
-            raise TypeError("current_id must be as int")
+        if not isinstance(server_endpoints, list):
+            raise TypeError("server_endpoints must be as string list")
+        elif len(server_endpoints) <= 0:
+            raise ValueError(
+                "the length of server_endpoints list must be greater than 0")
+        elif len(server_endpoints) != len(set(server_endpoints)):
+            raise ValueError("server_endpoints can't have duplicate elements")
        else:
-            if current_id < 0:
-                raise ValueError("current_id must be gather or equal 0")
-            self._current_id = current_id
+            for server_endpoint in server_endpoints:
+                if not isinstance(server_endpoint, str):
+                    raise TypeError(
+                        "every element in server_endpoints list must be as string")
+            self._server_endpoints = server_endpoints

        if role != Role.WORKER and role != Role.SERVER:
            raise TypeError("role must be as Role")
        else:
            self._role = role

+        if not isinstance(current_id, int):
+            raise TypeError("current_id must be as int")
+        else:
+            if current_id < 0:
+                raise ValueError(
+                    "current_id must be greater than or equal to 0")
+            elif self._role == Role.SERVER and current_id >= len(
+                    server_endpoints):
+                raise ValueError(
+                    "if role is Role.SERVER, current_id must be less than or equal to len(server_endpoints) - 1")
+            self._current_id = current_id
+
        if not isinstance(worker_num, int):
            raise TypeError("worker_num must be as int")
        else:
-            if worker_num < 0:
-                raise ValueError("worker_num must be gather or equal 0")
+            if worker_num <= 0:
+                raise ValueError("worker_num must be greater than 0")
            self._worker_num = worker_num

-        if not isinstance(server_endpoints, list):
-            raise TypeError("server_endpoints must be as string list")
-        else:
-            self._server_endpoints = server_endpoints
-
    def generate_role(self):
        self._role_is_generated = True
@@ -489,17 +505,33 @@ class UserDefinedCollectiveRoleMaker(RoleMakerBase):
        """
        super(UserDefinedCollectiveRoleMaker, self).__init__()

+        if not isinstance(worker_endpoints, list):
+            raise TypeError("worker_endpoints must be as string list")
+        elif len(worker_endpoints) <= 0:
+            raise ValueError(
+                "the length of worker_endpoints list must be greater than 0")
+        elif len(worker_endpoints) != len(set(worker_endpoints)):
+            raise ValueError("worker_endpoints can't have duplicate elements")
+        else:
+            for worker_endpoint in worker_endpoints:
+                if not isinstance(worker_endpoint, str):
+                    raise TypeError(
+                        "every element in worker_endpoints list must be as string")
+            self._worker_endpoints = worker_endpoints
+
        if not isinstance(current_id, int):
            raise TypeError("current_id must be as int")
        else:
            if current_id < 0:
-                raise ValueError("current_id must be greater or equal 0")
+                raise ValueError(
+                    "current_id must be greater than or equal to 0")
+            elif current_id >= len(worker_endpoints):
+                raise ValueError(
+                    "current_id must be less than or equal to len(worker_endpoints) - 1")
            self._current_id = current_id

-        if not isinstance(worker_endpoints, list):
-            raise TypeError("worker_endpoints must be as string list")
-        else:
-            self._worker_endpoints = worker_endpoints
-
        self._worker_num = len(self._worker_endpoints)

    def generate_role(self):
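For reference, a hedged sketch of how the reworked constructor above behaves (argument names taken from the diff and the new test file; treat it as an illustration, not as the commit's code):

# Sketch only; assumes the constructor order used in the test file:
# current_id, role, worker_num, server_endpoints.
from paddle.fluid.incubate.fleet.base.role_maker import Role, UserDefinedRoleMaker

# Valid: one worker, one pserver endpoint.
role_maker = UserDefinedRoleMaker(
    current_id=0,
    role=Role.WORKER,
    worker_num=1,
    server_endpoints=["127.0.0.1:8080"])

# Inputs the new validation rejects:
for bad in (dict(server_endpoints="127.0.0.1:8080"),       # not a list -> TypeError
            dict(server_endpoints=[]),                      # empty list -> ValueError
            dict(server_endpoints=["127.0.0.1:8080"] * 2),  # duplicates -> ValueError
            dict(worker_num=0)):                            # must be > 0 -> ValueError
    kwargs = dict(current_id=0, role=Role.WORKER, worker_num=1,
                  server_endpoints=["127.0.0.1:8080"])
    kwargs.update(bad)
    try:
        UserDefinedRoleMaker(**kwargs)
    except (TypeError, ValueError) as e:
        print(type(e).__name__, e)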
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
@@ -19,6 +19,9 @@ from paddle.fluid.communicator import Communicator
from paddle.fluid.framework import default_main_program
from paddle.fluid.framework import default_startup_program
from paddle.fluid.framework import Program
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.executor import Executor
from paddle.fluid.parallel_executor import ParallelExecutor
from paddle.fluid.optimizer import Optimizer
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspiler as OriginTranspiler
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
@@ -134,7 +137,7 @@ class DistributedTranspiler(Fleet):
        Args:
            optimizer(Optimizer): The executor to run for init server.
-            strategy(dict): Extra properties for distributed optimizer.
+            strategy(DistributeTranspilerConfig): Extra properties for distributed optimizer.

        Returns:
            TranspilerOptimizer: subclass of DistributedOptimizer.
@@ -156,7 +159,21 @@ class DistributedTranspiler(Fleet):
        Prune the given `main_program` to build a new program especially for inference,
        and then save it and all related parameters to given `dirname` by the `executor`.
        """
+        if isinstance(executor, ParallelExecutor):
+            raise TypeError(
+                "in fleet.save_inference_model() function, executor must be as Executor type, ParallelExecutor is not allowed"
+            )
+
+        if not isinstance(executor, Executor):
+            raise TypeError(
+                "in fleet.save_inference_model() function, executor must be as Executor type")
+
+        if main_program is not None:
+            if isinstance(main_program, CompiledProgram):
+                raise TypeError(
+                    "in fleet.save_inference_model() function, main_program must be as Program type, CompiledProgram is not allowed"
+                )
+
        io.save_inference_model(dirname, feeded_var_names, target_vars,
                                executor, main_program, None, None,
                                export_for_deployment)
@@ -186,10 +203,24 @@ class DistributedTranspiler(Fleet):
        files, set `filename` None; if you would like to save all variables in a
        single file, use `filename` to specify the file name.
        """
+        if isinstance(executor, ParallelExecutor):
+            raise TypeError(
+                "in fleet.save_persistables() function, executor must be as Executor type, ParallelExecutor is not allowed"
+            )
+
+        if not isinstance(executor, Executor):
+            raise TypeError(
+                "in fleet.save_persistables() function, executor must be as Executor type")
+
        if main_program is None:
            main_program = self.main_program

+        if isinstance(main_program, CompiledProgram):
+            raise TypeError(
+                "in fleet.save_persistables() function, main_program must be as Program type, CompiledProgram is not allowed"
+            )
+
        if not main_program._is_distributed:
            raise ValueError(
                "main_program is for local, may not use fleet.save_persistables")
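A short sketch of the new save_* checks (not part of the diff; it mirrors FleetTest.testInvalidInputs in the new test file): only a plain fluid.Executor and a plain Program are accepted.

# Sketch only.
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet

exe = fluid.Executor(fluid.CPUPlace())
compiled = fluid.compiler.CompiledProgram(fluid.default_main_program())

try:
    fleet.save_persistables(executor="not-an-executor", dirname='/tmp/')
except TypeError as e:
    print(e)  # executor must be as Executor type

try:
    fleet.save_persistables(executor=exe, dirname='/tmp/', main_program=compiled)
except TypeError as e:
    print(e)  # CompiledProgram is not allowed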
@@ -198,7 +229,7 @@ class DistributedTranspiler(Fleet):
    def _transpile(self, config):
        if not isinstance(config, DistributeTranspilerConfig):
-            raise ValueError(
+            raise TypeError(
                "config must be an instance of DistributeTranspilerConfig")

        if not config.sync_mode:
@@ -260,7 +291,7 @@ class TranspilerOptimizer(DistributedOptimizer):
        if strategy:
            if not isinstance(strategy, DistributeTranspilerConfig):
-                raise ValueError(
+                raise TypeError(
                    "In {} mode, strategy must be an instance of DistributeTranspilerConfig".
                    format(fleet._mode))
        else:
@@ -321,17 +352,34 @@ class TranspilerOptimizer(DistributedOptimizer):
    def minimize(self,
                 loss,
-                 scope=None,
+                 scopes=None,
                 startup_program=None,
                 parameter_list=None,
                 no_grad_set=None):
        """
        Add operations to minimize `loss` by updating `parameter_list`.

        This method combines interface `backward()` and
        `apply_gradients()` into one.

        Args:
            loss (Variable): loss variable to run optimizations.
            scopes (None): TranspilerOptimizer doesn't need scope parameter.
            startup_program (Program): startup_program for initializing parameters
                in `parameter_list`.
            parameter_list (list): list of Variables to update.
            no_grad_set (set|None): set of Variables should be ignored.

        Returns:
            tuple: (optimize_ops, params_grads) which are, list of operators appended;
            and list of (param, grad) Variables pair for optimization.
        """
        if isinstance(loss, list):
-            raise ValueError(
+            raise TypeError(
                "DistributedTranspiler's minimize can not accept loss with list")

        if isinstance(startup_program, list):
-            raise ValueError(
+            raise TypeError(
                "DistributedTranspiler's minimize can not accept program with list")
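And a sketch of the minimize() checks above (based on TranspilerOptimizerTest in the new test file; not part of the diff):

# Sketch only.
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import TranspilerOptimizer

transpiler = TranspilerOptimizer(fluid.optimizer.Adam(0.001))
try:
    transpiler.minimize(loss=[])  # loss must be a Variable, not a list
except TypeError as e:
    print(e)  # "DistributedTranspiler's minimize can not accept loss with list"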
python/paddle/fluid/tests/unittests/test_fleet_api_input.py
0 → 100644
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
from paddle.fluid.incubate.fleet.base.role_maker import UserDefinedRoleMaker
from paddle.fluid.incubate.fleet.base.role_maker import UserDefinedCollectiveRoleMaker
from paddle.fluid.incubate.fleet.base.role_maker import Role
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import TranspilerOptimizer


class DistributeTranspilerConfigTest(unittest.TestCase):
    def set_runtime_split_send_recv(self, config, value):
        config.runtime_split_send_recv = value

    def set_sync_mode(self, config, value):
        config.sync_mode = value

    def testConfig(self):
        config = DistributeTranspilerConfig()
        self.assertRaises(Exception, self.set_sync_mode, config, None)
        self.assertRaises(Exception, self.set_runtime_split_send_recv, config,
                          None)
        self.assertRaises(Exception, self.set_runtime_split_send_recv, config,
                          True)
        self.set_sync_mode(config, False)
        self.assertFalse(config.sync_mode)
        self.set_runtime_split_send_recv(config, True)
        self.assertRaises(Exception, self.set_sync_mode, config, True)


class FleetTest(unittest.TestCase):
    def testInvalidInputs(self):
        self.assertRaises(Exception, fleet.split_files, "files")
        self.assertRaises(Exception, fleet.init, "pserver")

        data = fluid.layers.data(name='X', shape=[1], dtype='float32')
        hidden = fluid.layers.fc(input=data, size=10)
        loss = fluid.layers.mean(hidden)
        adam = fluid.optimizer.Adam()
        adam.minimize(loss)
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        pe = fluid.ParallelExecutor(use_cuda=False, loss_name=loss.name)
        self.assertRaises(
            Exception,
            fleet.save_inference_model,
            dirname='/tmp/',
            feeded_var_names=['X'],
            target_vars=[loss],
            executor=pe)
        self.assertRaises(
            Exception,
            fleet.save_inference_model,
            dirname='/tmp/',
            feeded_var_names=['X'],
            target_vars=[loss],
            executor="executor")
        compiled_prog = fluid.compiler.CompiledProgram(
            fluid.default_main_program())
        self.assertRaises(
            Exception,
            fleet.save_inference_model,
            dirname='/tmp/',
            feeded_var_names=['X'],
            target_vars=[loss],
            executor=exe,
            main_program=compiled_prog)
        self.assertRaises(
            Exception, fleet.save_persistables, executor=pe, dirname='/tmp/')
        self.assertRaises(
            Exception,
            fleet.save_persistables,
            executor="executor",
            dirname='/tmp/')
        self.assertRaises(
            Exception,
            fleet.save_persistables,
            executor=exe,
            dirname='/tmp/',
            main_program=compiled_prog)
        self.assertRaises(Exception, fleet._transpile, "config")


class TranspilerOptimizerTest(unittest.TestCase):
    def testInvalidInputs(self):
        self.assertRaises(Exception, TranspilerOptimizer, "Adam", None)
        self.assertRaises(Exception, TranspilerOptimizer,
                          fluid.optimizer.Adam(0.001), "strategy")

        transpiler = TranspilerOptimizer(fluid.optimizer.Adam(0.001))
        self.assertRaises(Exception, transpiler.minimize, loss=[])
        data = fluid.layers.data(name='X', shape=[1], dtype='float32')
        hidden = fluid.layers.fc(input=data, size=10)
        loss = fluid.layers.mean(hidden)
        self.assertRaises(
            Exception, transpiler.minimize, loss=loss.name, startup_program=[])


class UserDefinedRoleMakerTest(unittest.TestCase):
    def createRoleMaker(self,
                        current_id=0,
                        role=Role.WORKER,
                        worker_num=1,
                        server_endpoints=["127.0.0.1:8080"]):
        role = UserDefinedRoleMaker(current_id, role, worker_num,
                                    server_endpoints)

    def testRoleMaker(self):
        self.createRoleMaker()
        ## test all invalid server_endpoints
        self.assertRaises(
            Exception, self.createRoleMaker,
            server_endpoints=None)  # server_endpoints must be as list
        self.assertRaises(
            Exception, self.createRoleMaker,
            server_endpoints=[])  # server_endpoints can't be empty
        self.assertRaises(
            Exception, self.createRoleMaker,
            server_endpoints=[3, []])  # element in server_endpoints must be as string
        self.assertRaises(
            Exception,
            self.createRoleMaker,
            server_endpoints=["127.0.0.1:8080", "127.0.0.1:8080"]
        )  # element in server_endpoints can't be duplicate
        ## test all invalid current_id
        self.assertRaises(
            Exception, self.createRoleMaker,
            current_id="0")  # current_id must be as int
        self.assertRaises(
            Exception, self.createRoleMaker,
            current_id=-1)  # current_id must be greater than or equal to 0
        self.assertRaises(
            Exception,
            self.createRoleMaker,
            current_id=1,
            role=Role.SERVER,
            server_endpoints=["127.0.0.1:8080"]
        )  # if role is server, current_id must be less than len(server_endpoints)
        ## test all invalid worker_num
        self.assertRaises(
            Exception, self.createRoleMaker,
            worker_num="1")  # worker_num must be as int
        self.assertRaises(
            Exception, self.createRoleMaker,
            worker_num=0)  # worker_num must be greater than 0
        ## test all invalid role
        self.assertRaises(
            Exception, self.createRoleMaker,
            role=3)  # role must be as Role(Role.WORKER=1, Role.SERVER=2)


class UserDefinedCollectiveRoleMakerTest(unittest.TestCase):
    def createRoleMaker(self,
                        current_id=0,
                        worker_endpoints=["127.0.0.1:8080"]):
        role = UserDefinedCollectiveRoleMaker(current_id, worker_endpoints)

    def testRoleMaker(self):
        self.createRoleMaker()
        ## test all invalid worker_endpoints
        self.assertRaises(
            Exception, self.createRoleMaker,
            worker_endpoints=None)  # worker_endpoints must be as list
        self.assertRaises(
            Exception, self.createRoleMaker,
            worker_endpoints=[])  # worker_endpoints can't be empty
        self.assertRaises(
            Exception, self.createRoleMaker,
            worker_endpoints=[3, []])  # element worker_endpoints must be as string
        self.assertRaises(
            Exception,
            self.createRoleMaker,
            worker_endpoints=["127.0.0.1:8080", "127.0.0.1:8080"]
        )  # element in worker_endpoints can't be duplicate
        ## test all invalid current_id
        self.assertRaises(
            Exception, self.createRoleMaker,
            current_id="0")  # current_id must be as int
        self.assertRaises(
            Exception, self.createRoleMaker,
            current_id=-1)  # current_id must be greater than or equal to 0
        self.assertRaises(
            Exception,
            self.createRoleMaker,
            current_id=1,
            worker_endpoints=["127.0.0.1:8080"]
        )  # current_id must be less than len(worker_endpoints)


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -163,8 +163,8 @@ class DistributeTranspilerConfig(object):
    print_log = False
    wait_port = True
    # split the send recv var in runtime
-    runtime_split_send_recv = False
-    sync_mode = True
+    _runtime_split_send_recv = False
+    _sync_mode = True
    nccl_comm_num = 1
    #The picture here illustrates the principle:
@@ -177,6 +177,37 @@ class DistributeTranspilerConfig(object):
    # supported modes: grad_allreduce, local_sgd
    collective_mode = None

+    def __init__(self):
+        pass
+
+    @property
+    def runtime_split_send_recv(self):
+        return self._runtime_split_send_recv
+
+    @runtime_split_send_recv.setter
+    def runtime_split_send_recv(self, value):
+        if value is None:
+            raise ValueError("runtime_split_send_recv can't be None")
+
+        if value and self._sync_mode:
+            raise ValueError(
+                "if you want to set runtime_split_send_recv to be true, make ensure config.sync_mode is false at first"
+            )
+        self._runtime_split_send_recv = value
+
+    @property
+    def sync_mode(self):
+        return self._sync_mode
+
+    @sync_mode.setter
+    def sync_mode(self, value):
+        if value is None:
+            raise ValueError("sync_mode can't be None")
+
+        if value and self._runtime_split_send_recv:
+            raise ValueError(
+                "if you want to set sync_mode to be true, make ensure config.runtime_split_send_recv is false at first"
+            )
+        self._sync_mode = value
+

class DistributeTranspiler(object):
    """
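A sketch of the new property behavior (mirrors DistributeTranspilerConfigTest in the test file above; not part of the diff): sync_mode and runtime_split_send_recv are now validated through setters and cannot both be enabled.

# Sketch only.
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig

config = DistributeTranspilerConfig()

try:
    config.runtime_split_send_recv = True   # sync_mode is still True (default)
except ValueError as e:
    print(e)

config.sync_mode = False                    # turn sync mode off first
config.runtime_split_send_recv = True       # now accepted

try:
    config.sync_mode = True                 # conflicts with runtime_split_send_recv
except ValueError as e:
    print(e)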