Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f84b54eb
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f84b54eb
编写于
3月 15, 2022
作者:
Y
Yulong Ao
提交者:
GitHub
3月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto parallel] Redesign the tuner for auto parallel (#40121)
* [Auto Parallel] Redesign the tunner for Auto Parallel
上级
0c333543
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
680 addition
and
0 deletion
+680
-0
python/paddle/distributed/auto_parallel/tuner/__init__.py
python/paddle/distributed/auto_parallel/tuner/__init__.py
+13
-0
python/paddle/distributed/auto_parallel/tuner/storable.py
python/paddle/distributed/auto_parallel/tuner/storable.py
+36
-0
python/paddle/distributed/auto_parallel/tuner/tunable_space.py
...n/paddle/distributed/auto_parallel/tuner/tunable_space.py
+151
-0
python/paddle/distributed/auto_parallel/tuner/tunable_variable.py
...addle/distributed/auto_parallel/tuner/tunable_variable.py
+242
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_space.py
...fluid/tests/unittests/auto_parallel/test_tunable_space.py
+138
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_variable.py
...id/tests/unittests/auto_parallel/test_tunable_variable.py
+99
-0
python/setup.py.in
python/setup.py.in
+1
-0
未找到文件。
python/paddle/distributed/auto_parallel/tuner/__init__.py
0 → 100644
浏览文件 @
f84b54eb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
python/paddle/distributed/auto_parallel/tuner/storable.py
0 → 100644
浏览文件 @
f84b54eb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
class
Storable
(
object
):
def
get_state
(
self
):
raise
NotImplementedError
def
set_state
(
self
,
state
):
raise
NotImplementedError
def
save
(
self
,
path
):
state
=
self
.
get_state
()
state_json
=
json
.
dumps
(
state
)
with
open
(
path
,
"w"
)
as
f
:
f
.
write
(
state_json
)
return
str
(
path
)
def
load
(
self
,
path
):
with
open
(
path
,
"r"
)
as
f
:
state_data
=
f
.
read
()
state
=
json
.
loads
(
state_data
)
self
.
set_state
(
state
)
python/paddle/distributed/auto_parallel/tuner/tunable_space.py
0 → 100644
浏览文件 @
f84b54eb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
collections
import
contextlib
import
copy
import
math
import
random
import
numpy
as
np
from
.tunable_variable
import
Boolean
from
.tunable_variable
import
Fixed
from
.tunable_variable
import
Choice
from
.tunable_variable
import
IntRange
from
.tunable_variable
import
FloatRange
class
TunableSpace
(
object
):
"""
A TunableSpace is constructed by the tunable variables.
"""
def
__init__
(
self
):
# Tunable variables for this tunable variables
self
.
_variables
=
{}
# Specific values coresponding to each tunable variable
self
.
_values
=
{}
@
property
def
variables
(
self
):
return
self
.
_variables
@
property
def
values
(
self
):
return
self
.
_values
def
get_value
(
self
,
name
):
if
name
in
self
.
values
:
return
self
.
values
[
name
]
else
:
raise
KeyError
(
"{} does not exist."
.
format
(
name
))
def
set_value
(
self
,
name
,
value
):
if
name
in
self
.
values
:
self
.
values
[
name
]
=
value
else
:
raise
KeyError
(
"{} does not exist."
.
format
(
name
))
def
_exists
(
self
,
name
):
if
name
in
self
.
_variables
:
return
True
return
False
def
_retrieve
(
self
,
tv
):
tv
=
tv
.
__class__
.
from_state
(
tv
.
get_state
())
if
self
.
_exists
(
tv
.
name
):
return
self
.
get_value
(
tv
.
name
)
return
self
.
_register
(
tv
)
def
_register
(
self
,
tv
):
self
.
_variables
[
tv
.
name
]
=
tv
if
tv
.
name
not
in
self
.
values
:
self
.
values
[
tv
.
name
]
=
tv
.
default
return
self
.
values
[
tv
.
name
]
def
__getitem__
(
self
,
name
):
return
self
.
get_value
(
name
)
def
__setitem__
(
self
,
name
,
value
):
self
.
set_value
(
name
,
value
)
def
__contains__
(
self
,
name
):
try
:
self
.
get_value
(
name
)
return
True
except
(
KeyError
,
ValueError
):
return
False
def
fixed
(
self
,
name
,
default
):
tv
=
Fixed
(
name
=
name
,
default
=
default
)
return
self
.
_retrieve
(
tv
)
def
boolean
(
self
,
name
,
default
=
False
):
tv
=
Boolean
(
name
=
name
,
default
=
default
)
return
self
.
_retrieve
(
tv
)
def
choice
(
self
,
name
,
values
,
default
=
None
):
tv
=
Choice
(
name
=
name
,
values
=
values
,
default
=
default
)
return
self
.
_retrieve
(
tv
)
def
int_range
(
self
,
name
,
start
,
stop
,
step
=
1
,
default
=
None
):
tv
=
IntRange
(
name
=
name
,
start
=
start
,
stop
=
stop
,
step
=
step
,
default
=
default
)
return
self
.
_retrieve
(
tv
)
def
float_range
(
self
,
name
,
start
,
stop
,
step
=
None
,
default
=
None
):
tv
=
FloatRange
(
name
=
name
,
start
=
start
,
stop
=
stop
,
step
=
step
,
default
=
default
)
return
self
.
_retrieve
(
tv
)
def
get_state
(
self
):
return
{
"variables"
:
[{
"class_name"
:
v
.
__class__
.
__name__
,
"state"
:
v
.
get_state
()
}
for
v
in
self
.
_variables
.
values
()],
"values"
:
dict
((
k
,
v
)
for
(
k
,
v
)
in
self
.
values
.
items
())
}
@
classmethod
def
from_state
(
cls
,
state
):
ts
=
cls
()
for
v
in
state
[
"variables"
]:
v
=
_deserialize_tunable_variable
(
v
)
ts
.
_variables
[
v
.
name
]
=
v
ts
.
_values
=
dict
((
k
,
v
)
for
(
k
,
v
)
in
state
[
"values"
].
items
())
return
ts
def
_deserialize_tunable_variable
(
state
):
classes
=
(
Boolean
,
Fixed
,
Choice
,
IntRange
,
FloatRange
)
cls_name_to_cls
=
{
cls
.
__name__
:
cls
for
cls
in
classes
}
if
isinstance
(
state
,
classes
):
return
state
if
(
not
isinstance
(
state
,
dict
)
or
"class_name"
not
in
state
or
"state"
not
in
state
):
raise
ValueError
(
"Expect state to be a python dict containing class_name and state as keys, but found {}"
.
format
(
state
))
cls_name
=
state
[
"class_name"
]
cls
=
cls_name_to_cls
[
cls_name
]
if
cls
is
None
:
raise
ValueError
(
"Unknown class name {}"
.
format
(
cls_name
))
cls_state
=
state
[
"state"
]
deserialized_object
=
cls
.
from_state
(
cls_state
)
return
deserialized_object
python/paddle/distributed/auto_parallel/tuner/tunable_variable.py
0 → 100644
浏览文件 @
f84b54eb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
class
TunableVariable
(
object
):
"""
Tunablevariable base class.
"""
def
__init__
(
self
,
name
,
default
=
None
):
self
.
name
=
name
self
.
_default
=
default
@
property
def
default
(
self
):
return
self
.
_default
def
get_state
(
self
):
return
{
"name"
:
self
.
name
,
"default"
:
self
.
default
}
@
classmethod
def
from_state
(
cls
,
state
):
return
cls
(
**
state
)
class
Fixed
(
TunableVariable
):
"""
Fixed variable which cannot be changed.
"""
def
__init__
(
self
,
name
,
default
):
super
(
Fixed
,
self
).
__init__
(
name
=
name
,
default
=
default
)
self
.
name
=
name
if
not
isinstance
(
default
,
(
str
,
int
,
float
,
bool
)):
raise
ValueError
(
"Fixed must be an str, int, float or bool, but found {}"
.
format
(
default
))
self
.
_default
=
default
def
random
(
self
,
seed
=
None
):
return
self
.
_default
def
__repr__
(
self
):
return
"Fixed(name: {}, value: {})"
.
format
(
self
.
name
,
self
.
default
)
class
Boolean
(
TunableVariable
):
"""
Choice between True and False.
"""
def
__init__
(
self
,
name
,
default
=
False
):
super
(
Boolean
,
self
).
__init__
(
name
=
name
,
default
=
default
)
if
default
not
in
{
True
,
False
}:
raise
ValueError
(
"default must be a Python boolean, but got {}"
.
format
(
default
))
def
random
(
self
,
seed
=
None
):
rng
=
np
.
random
.
default_rng
(
seed
)
return
rng
.
choice
((
True
,
False
))
def
__repr__
(
self
):
return
'Boolean(name: "{}", default: {})'
.
format
(
self
.
name
,
self
.
default
)
class
Choice
(
TunableVariable
):
def
__init__
(
self
,
name
,
values
,
default
=
None
):
super
(
Choice
,
self
).
__init__
(
name
=
name
,
default
=
default
)
types
=
set
(
type
(
v
)
for
v
in
values
)
if
len
(
types
)
>
1
:
raise
TypeError
(
"Choice can contain only one type of value, but found values: {} with types: {}."
.
format
(
str
(
values
),
str
(
types
)))
if
isinstance
(
values
[
0
],
str
):
values
=
[
str
(
v
)
for
v
in
values
]
if
default
is
not
None
:
default
=
str
(
default
)
elif
isinstance
(
values
[
0
],
int
):
values
=
[
int
(
v
)
for
v
in
values
]
if
default
is
not
None
:
default
=
int
(
default
)
elif
isinstance
(
values
[
0
],
float
):
values
=
[
float
(
v
)
for
v
in
values
]
if
default
is
not
None
:
default
=
float
(
default
)
elif
isinstance
(
values
[
0
],
bool
):
values
=
[
bool
(
v
)
for
v
in
values
]
if
default
is
not
None
:
default
=
bool
(
default
)
else
:
raise
TypeError
(
"Choice can only contain str, int, float, or boll, but found: {} "
.
format
(
str
(
values
)))
self
.
values
=
values
if
default
is
not
None
and
default
not
in
values
:
raise
ValueError
(
"The default value should be one of the choices {}, but found {}"
.
format
(
values
,
default
))
self
.
_default
=
default
@
property
def
default
(
self
):
if
self
.
_default
is
None
:
if
None
in
self
.
values
:
return
None
return
self
.
values
[
0
]
return
self
.
_default
def
random
(
self
,
seed
=
None
):
rng
=
np
.
random
.
default_rng
(
seed
)
return
rng
.
choice
(
self
.
values
)
def
get_state
(
self
):
state
=
super
(
Choice
,
self
).
get_state
()
state
[
"values"
]
=
self
.
values
return
state
def
__repr__
(
self
):
return
'Choice(name: "{}", values: {}, default: {})'
.
format
(
self
.
name
,
self
.
values
,
self
.
default
)
class
IntRange
(
TunableVariable
):
"""
Integer range.
"""
def
__init__
(
self
,
name
,
start
,
stop
,
step
=
1
,
default
=
None
,
endpoint
=
False
):
super
(
IntRange
,
self
).
__init__
(
name
=
name
,
default
=
default
)
self
.
start
=
self
.
_check_int
(
start
)
self
.
stop
=
self
.
_check_int
(
stop
)
self
.
step
=
self
.
_check_int
(
step
)
self
.
_default
=
default
self
.
endpoint
=
endpoint
@
property
def
default
(
self
):
if
self
.
_default
is
not
None
:
return
self
.
_default
return
self
.
start
def
random
(
self
,
seed
=
None
):
rng
=
np
.
random
.
default_rng
(
seed
)
value
=
(
self
.
stop
-
self
.
start
)
*
rng
.
random
()
+
self
.
start
if
self
.
step
is
not
None
:
if
self
.
endpoint
:
values
=
np
.
arange
(
self
.
start
,
self
.
stop
+
1e-7
,
step
=
self
.
step
)
else
:
values
=
np
.
arange
(
self
.
start
,
self
.
stop
,
step
=
self
.
step
)
closest_index
=
np
.
abs
(
values
-
value
).
argmin
()
value
=
values
[
closest_index
]
return
int
(
value
)
def
get_state
(
self
):
state
=
super
(
IntRange
,
self
).
get_state
()
state
[
"start"
]
=
self
.
start
state
[
"stop"
]
=
self
.
stop
state
[
"step"
]
=
self
.
step
state
[
"default"
]
=
self
.
_default
return
state
def
_check_int
(
self
,
val
):
int_val
=
int
(
val
)
if
int_val
!=
val
:
raise
ValueError
(
"Expects val is an int, but found: {}."
.
format
(
str
(
val
)))
return
int_val
def
__repr__
(
self
):
return
"IntRange(name: {}, start: {}, stop: {}, step: {}, default: {})"
.
format
(
self
.
name
,
self
.
start
,
self
.
stop
,
self
.
step
,
self
.
default
)
class
FloatRange
(
TunableVariable
):
"""
Float range.
"""
def
__init__
(
self
,
name
,
start
,
stop
,
step
=
None
,
default
=
None
,
endpoint
=
False
):
super
(
FloatRange
,
self
).
__init__
(
name
=
name
,
default
=
default
)
self
.
stop
=
float
(
stop
)
self
.
start
=
float
(
start
)
if
step
is
not
None
:
self
.
step
=
float
(
step
)
else
:
self
.
step
=
None
self
.
_default
=
default
self
.
endpoint
=
endpoint
@
property
def
default
(
self
):
if
self
.
_default
is
not
None
:
return
self
.
_default
return
self
.
start
def
random
(
self
,
seed
=
None
):
rng
=
np
.
random
.
default_rng
(
seed
)
value
=
(
self
.
stop
-
self
.
start
)
*
rng
.
random
()
+
self
.
start
if
self
.
step
is
not
None
:
if
self
.
endpoint
:
values
=
np
.
arange
(
self
.
start
,
self
.
stop
+
1e-7
,
step
=
self
.
step
)
else
:
values
=
np
.
arange
(
self
.
start
,
self
.
stop
,
step
=
self
.
step
)
closest_index
=
np
.
abs
(
values
-
value
).
argmin
()
value
=
values
[
closest_index
]
return
value
def
get_state
(
self
):
state
=
super
(
FloatRange
,
self
).
get_state
()
state
[
"start"
]
=
self
.
start
state
[
"stop"
]
=
self
.
stop
state
[
"step"
]
=
self
.
step
state
[
"endpoint"
]
=
self
.
endpoint
return
state
def
__repr__
(
self
):
return
"FloatRange(name: {}, start: {}, stop: {}, step: {}, default: {}, endpoint: {})"
.
format
(
self
.
name
,
self
.
start
,
self
.
stop
,
self
.
step
,
self
.
default
,
self
.
endpoint
)
python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_space.py
0 → 100644
浏览文件 @
f84b54eb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
paddle.distributed.auto_parallel.tuner
import
tunable_space
as
ts
class
TestTunableSpace
(
unittest
.
TestCase
):
def
test_fixed
(
self
):
space
=
ts
.
TunableSpace
()
fixed
=
space
.
fixed
(
"fixed"
,
default
=
4
)
self
.
assertEqual
(
space
.
values
[
"fixed"
],
4
)
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"fixed"
].
name
,
"fixed"
)
space
.
values
[
"fixed"
]
=
2
self
.
assertEqual
(
space
.
get_value
(
"fixed"
),
2
)
self
.
assertEqual
(
space
.
values
,
{
"fixed"
:
2
})
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"fixed"
].
name
,
"fixed"
)
def
test_boolean
(
self
):
space
=
ts
.
TunableSpace
()
boolean
=
space
.
boolean
(
"boolean"
)
self
.
assertEqual
(
space
.
values
[
"boolean"
],
False
)
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"boolean"
].
name
,
"boolean"
)
space
.
values
[
"boolean"
]
=
True
self
.
assertEqual
(
space
.
get_value
(
"boolean"
),
True
)
self
.
assertEqual
(
space
.
values
,
{
"boolean"
:
True
})
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"boolean"
].
name
,
"boolean"
)
def
test_choice
(
self
):
space
=
ts
.
TunableSpace
()
choice
=
space
.
choice
(
"choice"
,
[
1
,
2
,
3
,
4
],
default
=
4
)
self
.
assertEqual
(
space
.
values
[
"choice"
],
4
)
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"choice"
].
name
,
"choice"
)
space
.
values
[
"choice"
]
=
2
self
.
assertEqual
(
space
.
get_value
(
"choice"
),
2
)
self
.
assertEqual
(
space
.
values
,
{
"choice"
:
2
})
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"choice"
].
name
,
"choice"
)
def
test_int_range
(
self
):
space
=
ts
.
TunableSpace
()
int_range
=
space
.
int_range
(
"int_range"
,
start
=
1
,
stop
=
4
,
default
=
2
)
self
.
assertEqual
(
space
.
values
[
"int_range"
],
2
)
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"int_range"
].
name
,
"int_range"
)
space
.
values
[
"int_range"
]
=
3
self
.
assertEqual
(
space
.
get_value
(
"int_range"
),
3
)
self
.
assertEqual
(
space
.
values
,
{
"int_range"
:
3
})
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"int_range"
].
name
,
"int_range"
)
def
test_float_range
(
self
):
space
=
ts
.
TunableSpace
()
float_range
=
space
.
float_range
(
"float_range"
,
start
=
0.4
,
stop
=
4.4
,
default
=
2.0
)
self
.
assertEqual
(
space
.
values
[
"float_range"
],
2.0
)
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"float_range"
].
name
,
"float_range"
)
space
.
values
[
"float_range"
]
=
3.0
self
.
assertEqual
(
space
.
get_value
(
"float_range"
),
3.0
)
self
.
assertEqual
(
space
.
values
,
{
"float_range"
:
3.0
})
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"float_range"
].
name
,
"float_range"
)
def
test_varibles
(
self
):
space
=
ts
.
TunableSpace
()
choice
=
space
.
choice
(
"choice"
,
[
1
,
2
,
3
,
4
],
default
=
4
)
self
.
assertEqual
(
space
.
values
[
"choice"
],
4
)
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"choice"
].
name
,
"choice"
)
int_range
=
space
.
int_range
(
"int_range"
,
start
=
1
,
stop
=
4
,
default
=
2
)
self
.
assertEqual
(
space
.
values
[
"int_range"
],
2
)
self
.
assertEqual
(
len
(
space
.
variables
),
2
)
self
.
assertEqual
(
space
.
variables
[
"int_range"
].
name
,
"int_range"
)
def
test_not_populated_variable
(
self
):
space
=
ts
.
TunableSpace
()
choice
=
space
.
choice
(
"choice"
,
[
1
,
2
,
3
,
4
],
default
=
2
)
self
.
assertEqual
(
choice
,
2
)
def
test_populated_variable
(
self
):
space
=
ts
.
TunableSpace
()
space
.
values
[
"choice"
]
=
2
choice
=
space
.
choice
(
"choice"
,
[
1
,
2
,
3
,
4
],
default
=
4
)
self
.
assertEqual
(
choice
,
2
)
space
[
"choice"
]
=
3
self
.
assertNotEqual
(
space
.
values
[
"choice"
],
2
)
self
.
assertEqual
(
space
.
values
[
"choice"
],
3
)
def
test_state
(
self
):
space
=
ts
.
TunableSpace
()
choice
=
space
.
choice
(
"choice"
,
[
1
,
2
,
3
,
4
],
default
=
4
)
int_range
=
space
.
int_range
(
"int_range"
,
start
=
1
,
stop
=
4
,
default
=
2
)
new_space
=
space
.
from_state
(
space
.
get_state
())
self
.
assertEqual
(
new_space
.
get_value
(
"choice"
),
4
)
self
.
assertEqual
(
new_space
.
get_value
(
"int_range"
),
2
)
self
.
assertEqual
(
len
(
new_space
.
variables
),
2
)
self
.
assertEqual
(
len
(
new_space
.
values
),
2
)
self
.
assertEqual
(
new_space
.
variables
[
"choice"
].
name
,
"choice"
)
self
.
assertEqual
(
new_space
.
variables
[
"choice"
].
default
,
4
)
self
.
assertEqual
(
new_space
.
variables
[
"choice"
].
values
,
[
1
,
2
,
3
,
4
])
self
.
assertEqual
(
new_space
.
variables
[
"int_range"
].
name
,
"int_range"
)
self
.
assertEqual
(
new_space
.
variables
[
"int_range"
].
default
,
2
)
self
.
assertEqual
(
new_space
.
variables
[
"int_range"
].
start
,
1
)
self
.
assertEqual
(
new_space
.
variables
[
"int_range"
].
stop
,
4
)
self
.
assertEqual
(
new_space
.
variables
[
"int_range"
].
step
,
1
)
self
.
assertEqual
(
new_space
.
variables
[
"int_range"
].
endpoint
,
False
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_variable.py
0 → 100644
浏览文件 @
f84b54eb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
paddle.distributed.auto_parallel.tuner
import
tunable_variable
as
tv
class
TestTunableVariable
(
unittest
.
TestCase
):
def
test_fixed
(
self
):
fixed
=
tv
.
Fixed
(
"fixed"
,
True
)
fixed
=
tv
.
Fixed
.
from_state
(
fixed
.
get_state
())
self
.
assertEqual
(
fixed
.
default
,
True
)
self
.
assertEqual
(
fixed
.
random
(),
True
)
fixed
=
tv
.
Fixed
(
"fixed"
,
1
)
fixed
=
tv
.
Fixed
.
from_state
(
fixed
.
get_state
())
self
.
assertEqual
(
fixed
.
default
,
1
)
self
.
assertEqual
(
fixed
.
random
(),
1
)
def
test_boolean
(
self
):
boolean
=
tv
.
Boolean
(
"bool"
)
boolean
=
tv
.
Boolean
.
from_state
(
boolean
.
get_state
())
self
.
assertEqual
(
boolean
.
default
,
False
)
self
.
assertIn
(
boolean
.
random
(),
[
True
,
False
])
self
.
assertIn
(
boolean
.
random
(
1234
),
[
True
,
False
])
boolean
=
tv
.
Boolean
(
"bool"
,
True
)
boolean
=
tv
.
Boolean
.
from_state
(
boolean
.
get_state
())
self
.
assertEqual
(
boolean
.
default
,
True
)
self
.
assertIn
(
boolean
.
random
(),
[
True
,
False
])
self
.
assertIn
(
boolean
.
random
(
1234
),
[
True
,
False
])
def
test_choice
(
self
):
choice
=
tv
.
Choice
(
"choice"
,
[
1
,
2
,
3
,
4
])
choice
=
tv
.
Choice
.
from_state
(
choice
.
get_state
())
self
.
assertEqual
(
choice
.
default
,
1
)
self
.
assertIn
(
choice
.
random
(),
[
1
,
2
,
3
,
4
])
self
.
assertIn
(
choice
.
random
(
1234
),
[
1
,
2
,
3
,
4
])
choice
=
tv
.
Choice
(
"choice"
,
[
1
,
2
,
3
,
4
],
default
=
2
)
choice
=
tv
.
Choice
.
from_state
(
choice
.
get_state
())
self
.
assertEqual
(
choice
.
default
,
2
)
self
.
assertIn
(
choice
.
random
(),
[
1
,
2
,
3
,
4
])
self
.
assertIn
(
choice
.
random
(
1234
),
[
1
,
2
,
3
,
4
])
def
test_int_range
(
self
):
int_range
=
tv
.
IntRange
(
"int_range"
,
start
=
1
,
stop
=
4
,
default
=
2
)
int_range
=
tv
.
IntRange
.
from_state
(
int_range
.
get_state
())
self
.
assertEqual
(
int_range
.
default
,
2
)
self
.
assertIn
(
int_range
.
random
(),
[
1
,
2
,
3
,
4
])
self
.
assertIn
(
int_range
.
random
(
1234
),
[
1
,
2
,
3
,
4
])
self
.
assertNotEqual
(
int_range
.
default
,
4
)
int_range
=
tv
.
IntRange
(
"int_range"
,
start
=
1
,
stop
=
8
,
step
=
2
,
default
=
3
,
endpoint
=
True
)
int_range
=
tv
.
IntRange
.
from_state
(
int_range
.
get_state
())
self
.
assertEqual
(
int_range
.
default
,
3
)
self
.
assertIn
(
int_range
.
random
(),
[
1
,
3
,
5
,
7
])
self
.
assertIn
(
int_range
.
random
(
1234
),
[
1
,
3
,
5
,
7
])
self
.
assertNotEqual
(
int_range
.
default
,
2
)
def
test_float_range
(
self
):
float_range
=
tv
.
FloatRange
(
"float_range"
,
start
=
0.4
,
stop
=
4.4
,
default
=
2.0
)
float_range
=
tv
.
FloatRange
.
from_state
(
float_range
.
get_state
())
self
.
assertEqual
(
float_range
.
default
,
2.0
)
self
.
assertGreater
(
float_range
.
random
(),
0.4
)
self
.
assertLess
(
float_range
.
random
(
1234
),
4.4
)
self
.
assertNotAlmostEqual
(
float_range
.
random
(),
1
)
self
.
assertNotAlmostEqual
(
float_range
.
random
(),
4.4
)
float_range
=
tv
.
FloatRange
(
"float_range"
,
start
=
0.4
,
stop
=
8.4
,
step
=
2.0
,
default
=
3.0
,
endpoint
=
True
)
float_range
=
tv
.
FloatRange
.
from_state
(
float_range
.
get_state
())
self
.
assertEqual
(
float_range
.
default
,
3.0
)
self
.
assertGreater
(
float_range
.
random
(),
0.4
)
self
.
assertLessEqual
(
float_range
.
random
(
1234
),
8.4
)
self
.
assertNotAlmostEqual
(
float_range
.
random
(),
2
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/setup.py.in
浏览文件 @
f84b54eb
...
@@ -300,6 +300,7 @@ packages=['paddle',
...
@@ -300,6 +300,7 @@ packages=['paddle',
'paddle.distributed.fleet.meta_parallel.parallel_layers',
'paddle.distributed.fleet.meta_parallel.parallel_layers',
'paddle.distributed.auto_parallel',
'paddle.distributed.auto_parallel',
'paddle.distributed.auto_parallel.operators',
'paddle.distributed.auto_parallel.operators',
'paddle.distributed.auto_parallel.tuner',
'paddle.distributed.passes',
'paddle.distributed.passes',
'paddle.framework',
'paddle.framework',
'paddle.jit',
'paddle.jit',
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录