Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
f84b54eb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f84b54eb
编写于
3月 15, 2022
作者:
Y
Yulong Ao
提交者:
GitHub
3月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Auto parallel] Redesign the tuner for auto parallel (#40121)
* [Auto Parallel] Redesign the tunner for Auto Parallel
上级
0c333543
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
680 addition
and
0 deletion
+680
-0
python/paddle/distributed/auto_parallel/tuner/__init__.py
python/paddle/distributed/auto_parallel/tuner/__init__.py
+13
-0
python/paddle/distributed/auto_parallel/tuner/storable.py
python/paddle/distributed/auto_parallel/tuner/storable.py
+36
-0
python/paddle/distributed/auto_parallel/tuner/tunable_space.py
...n/paddle/distributed/auto_parallel/tuner/tunable_space.py
+151
-0
python/paddle/distributed/auto_parallel/tuner/tunable_variable.py
...addle/distributed/auto_parallel/tuner/tunable_variable.py
+242
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_space.py
...fluid/tests/unittests/auto_parallel/test_tunable_space.py
+138
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_variable.py
...id/tests/unittests/auto_parallel/test_tunable_variable.py
+99
-0
python/setup.py.in
python/setup.py.in
+1
-0
未找到文件。
python/paddle/distributed/auto_parallel/tuner/__init__.py
0 → 100644
浏览文件 @
f84b54eb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
python/paddle/distributed/auto_parallel/tuner/storable.py
0 → 100644
浏览文件 @
f84b54eb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
json
class
Storable
(
object
):
def
get_state
(
self
):
raise
NotImplementedError
def
set_state
(
self
,
state
):
raise
NotImplementedError
def
save
(
self
,
path
):
state
=
self
.
get_state
()
state_json
=
json
.
dumps
(
state
)
with
open
(
path
,
"w"
)
as
f
:
f
.
write
(
state_json
)
return
str
(
path
)
def
load
(
self
,
path
):
with
open
(
path
,
"r"
)
as
f
:
state_data
=
f
.
read
()
state
=
json
.
loads
(
state_data
)
self
.
set_state
(
state
)
python/paddle/distributed/auto_parallel/tuner/tunable_space.py
0 → 100644
浏览文件 @
f84b54eb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
collections
import
contextlib
import
copy
import
math
import
random
import
numpy
as
np
from
.tunable_variable
import
Boolean
from
.tunable_variable
import
Fixed
from
.tunable_variable
import
Choice
from
.tunable_variable
import
IntRange
from
.tunable_variable
import
FloatRange
class
TunableSpace
(
object
):
"""
A TunableSpace is constructed by the tunable variables.
"""
def
__init__
(
self
):
# Tunable variables for this tunable variables
self
.
_variables
=
{}
# Specific values coresponding to each tunable variable
self
.
_values
=
{}
@
property
def
variables
(
self
):
return
self
.
_variables
@
property
def
values
(
self
):
return
self
.
_values
def
get_value
(
self
,
name
):
if
name
in
self
.
values
:
return
self
.
values
[
name
]
else
:
raise
KeyError
(
"{} does not exist."
.
format
(
name
))
def
set_value
(
self
,
name
,
value
):
if
name
in
self
.
values
:
self
.
values
[
name
]
=
value
else
:
raise
KeyError
(
"{} does not exist."
.
format
(
name
))
def
_exists
(
self
,
name
):
if
name
in
self
.
_variables
:
return
True
return
False
def
_retrieve
(
self
,
tv
):
tv
=
tv
.
__class__
.
from_state
(
tv
.
get_state
())
if
self
.
_exists
(
tv
.
name
):
return
self
.
get_value
(
tv
.
name
)
return
self
.
_register
(
tv
)
def
_register
(
self
,
tv
):
self
.
_variables
[
tv
.
name
]
=
tv
if
tv
.
name
not
in
self
.
values
:
self
.
values
[
tv
.
name
]
=
tv
.
default
return
self
.
values
[
tv
.
name
]
def
__getitem__
(
self
,
name
):
return
self
.
get_value
(
name
)
def
__setitem__
(
self
,
name
,
value
):
self
.
set_value
(
name
,
value
)
def
__contains__
(
self
,
name
):
try
:
self
.
get_value
(
name
)
return
True
except
(
KeyError
,
ValueError
):
return
False
def
fixed
(
self
,
name
,
default
):
tv
=
Fixed
(
name
=
name
,
default
=
default
)
return
self
.
_retrieve
(
tv
)
def
boolean
(
self
,
name
,
default
=
False
):
tv
=
Boolean
(
name
=
name
,
default
=
default
)
return
self
.
_retrieve
(
tv
)
def
choice
(
self
,
name
,
values
,
default
=
None
):
tv
=
Choice
(
name
=
name
,
values
=
values
,
default
=
default
)
return
self
.
_retrieve
(
tv
)
def
int_range
(
self
,
name
,
start
,
stop
,
step
=
1
,
default
=
None
):
tv
=
IntRange
(
name
=
name
,
start
=
start
,
stop
=
stop
,
step
=
step
,
default
=
default
)
return
self
.
_retrieve
(
tv
)
def
float_range
(
self
,
name
,
start
,
stop
,
step
=
None
,
default
=
None
):
tv
=
FloatRange
(
name
=
name
,
start
=
start
,
stop
=
stop
,
step
=
step
,
default
=
default
)
return
self
.
_retrieve
(
tv
)
def
get_state
(
self
):
return
{
"variables"
:
[{
"class_name"
:
v
.
__class__
.
__name__
,
"state"
:
v
.
get_state
()
}
for
v
in
self
.
_variables
.
values
()],
"values"
:
dict
((
k
,
v
)
for
(
k
,
v
)
in
self
.
values
.
items
())
}
@
classmethod
def
from_state
(
cls
,
state
):
ts
=
cls
()
for
v
in
state
[
"variables"
]:
v
=
_deserialize_tunable_variable
(
v
)
ts
.
_variables
[
v
.
name
]
=
v
ts
.
_values
=
dict
((
k
,
v
)
for
(
k
,
v
)
in
state
[
"values"
].
items
())
return
ts
def
_deserialize_tunable_variable
(
state
):
classes
=
(
Boolean
,
Fixed
,
Choice
,
IntRange
,
FloatRange
)
cls_name_to_cls
=
{
cls
.
__name__
:
cls
for
cls
in
classes
}
if
isinstance
(
state
,
classes
):
return
state
if
(
not
isinstance
(
state
,
dict
)
or
"class_name"
not
in
state
or
"state"
not
in
state
):
raise
ValueError
(
"Expect state to be a python dict containing class_name and state as keys, but found {}"
.
format
(
state
))
cls_name
=
state
[
"class_name"
]
cls
=
cls_name_to_cls
[
cls_name
]
if
cls
is
None
:
raise
ValueError
(
"Unknown class name {}"
.
format
(
cls_name
))
cls_state
=
state
[
"state"
]
deserialized_object
=
cls
.
from_state
(
cls_state
)
return
deserialized_object
python/paddle/distributed/auto_parallel/tuner/tunable_variable.py
0 → 100644
浏览文件 @
f84b54eb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
class
TunableVariable
(
object
):
"""
Tunablevariable base class.
"""
def
__init__
(
self
,
name
,
default
=
None
):
self
.
name
=
name
self
.
_default
=
default
@
property
def
default
(
self
):
return
self
.
_default
def
get_state
(
self
):
return
{
"name"
:
self
.
name
,
"default"
:
self
.
default
}
@
classmethod
def
from_state
(
cls
,
state
):
return
cls
(
**
state
)
class
Fixed
(
TunableVariable
):
"""
Fixed variable which cannot be changed.
"""
def
__init__
(
self
,
name
,
default
):
super
(
Fixed
,
self
).
__init__
(
name
=
name
,
default
=
default
)
self
.
name
=
name
if
not
isinstance
(
default
,
(
str
,
int
,
float
,
bool
)):
raise
ValueError
(
"Fixed must be an str, int, float or bool, but found {}"
.
format
(
default
))
self
.
_default
=
default
def
random
(
self
,
seed
=
None
):
return
self
.
_default
def
__repr__
(
self
):
return
"Fixed(name: {}, value: {})"
.
format
(
self
.
name
,
self
.
default
)
class
Boolean
(
TunableVariable
):
"""
Choice between True and False.
"""
def
__init__
(
self
,
name
,
default
=
False
):
super
(
Boolean
,
self
).
__init__
(
name
=
name
,
default
=
default
)
if
default
not
in
{
True
,
False
}:
raise
ValueError
(
"default must be a Python boolean, but got {}"
.
format
(
default
))
def
random
(
self
,
seed
=
None
):
rng
=
np
.
random
.
default_rng
(
seed
)
return
rng
.
choice
((
True
,
False
))
def
__repr__
(
self
):
return
'Boolean(name: "{}", default: {})'
.
format
(
self
.
name
,
self
.
default
)
class
Choice
(
TunableVariable
):
def
__init__
(
self
,
name
,
values
,
default
=
None
):
super
(
Choice
,
self
).
__init__
(
name
=
name
,
default
=
default
)
types
=
set
(
type
(
v
)
for
v
in
values
)
if
len
(
types
)
>
1
:
raise
TypeError
(
"Choice can contain only one type of value, but found values: {} with types: {}."
.
format
(
str
(
values
),
str
(
types
)))
if
isinstance
(
values
[
0
],
str
):
values
=
[
str
(
v
)
for
v
in
values
]
if
default
is
not
None
:
default
=
str
(
default
)
elif
isinstance
(
values
[
0
],
int
):
values
=
[
int
(
v
)
for
v
in
values
]
if
default
is
not
None
:
default
=
int
(
default
)
elif
isinstance
(
values
[
0
],
float
):
values
=
[
float
(
v
)
for
v
in
values
]
if
default
is
not
None
:
default
=
float
(
default
)
elif
isinstance
(
values
[
0
],
bool
):
values
=
[
bool
(
v
)
for
v
in
values
]
if
default
is
not
None
:
default
=
bool
(
default
)
else
:
raise
TypeError
(
"Choice can only contain str, int, float, or boll, but found: {} "
.
format
(
str
(
values
)))
self
.
values
=
values
if
default
is
not
None
and
default
not
in
values
:
raise
ValueError
(
"The default value should be one of the choices {}, but found {}"
.
format
(
values
,
default
))
self
.
_default
=
default
@
property
def
default
(
self
):
if
self
.
_default
is
None
:
if
None
in
self
.
values
:
return
None
return
self
.
values
[
0
]
return
self
.
_default
def
random
(
self
,
seed
=
None
):
rng
=
np
.
random
.
default_rng
(
seed
)
return
rng
.
choice
(
self
.
values
)
def
get_state
(
self
):
state
=
super
(
Choice
,
self
).
get_state
()
state
[
"values"
]
=
self
.
values
return
state
def
__repr__
(
self
):
return
'Choice(name: "{}", values: {}, default: {})'
.
format
(
self
.
name
,
self
.
values
,
self
.
default
)
class
IntRange
(
TunableVariable
):
"""
Integer range.
"""
def
__init__
(
self
,
name
,
start
,
stop
,
step
=
1
,
default
=
None
,
endpoint
=
False
):
super
(
IntRange
,
self
).
__init__
(
name
=
name
,
default
=
default
)
self
.
start
=
self
.
_check_int
(
start
)
self
.
stop
=
self
.
_check_int
(
stop
)
self
.
step
=
self
.
_check_int
(
step
)
self
.
_default
=
default
self
.
endpoint
=
endpoint
@
property
def
default
(
self
):
if
self
.
_default
is
not
None
:
return
self
.
_default
return
self
.
start
def
random
(
self
,
seed
=
None
):
rng
=
np
.
random
.
default_rng
(
seed
)
value
=
(
self
.
stop
-
self
.
start
)
*
rng
.
random
()
+
self
.
start
if
self
.
step
is
not
None
:
if
self
.
endpoint
:
values
=
np
.
arange
(
self
.
start
,
self
.
stop
+
1e-7
,
step
=
self
.
step
)
else
:
values
=
np
.
arange
(
self
.
start
,
self
.
stop
,
step
=
self
.
step
)
closest_index
=
np
.
abs
(
values
-
value
).
argmin
()
value
=
values
[
closest_index
]
return
int
(
value
)
def
get_state
(
self
):
state
=
super
(
IntRange
,
self
).
get_state
()
state
[
"start"
]
=
self
.
start
state
[
"stop"
]
=
self
.
stop
state
[
"step"
]
=
self
.
step
state
[
"default"
]
=
self
.
_default
return
state
def
_check_int
(
self
,
val
):
int_val
=
int
(
val
)
if
int_val
!=
val
:
raise
ValueError
(
"Expects val is an int, but found: {}."
.
format
(
str
(
val
)))
return
int_val
def
__repr__
(
self
):
return
"IntRange(name: {}, start: {}, stop: {}, step: {}, default: {})"
.
format
(
self
.
name
,
self
.
start
,
self
.
stop
,
self
.
step
,
self
.
default
)
class
FloatRange
(
TunableVariable
):
"""
Float range.
"""
def
__init__
(
self
,
name
,
start
,
stop
,
step
=
None
,
default
=
None
,
endpoint
=
False
):
super
(
FloatRange
,
self
).
__init__
(
name
=
name
,
default
=
default
)
self
.
stop
=
float
(
stop
)
self
.
start
=
float
(
start
)
if
step
is
not
None
:
self
.
step
=
float
(
step
)
else
:
self
.
step
=
None
self
.
_default
=
default
self
.
endpoint
=
endpoint
@
property
def
default
(
self
):
if
self
.
_default
is
not
None
:
return
self
.
_default
return
self
.
start
def
random
(
self
,
seed
=
None
):
rng
=
np
.
random
.
default_rng
(
seed
)
value
=
(
self
.
stop
-
self
.
start
)
*
rng
.
random
()
+
self
.
start
if
self
.
step
is
not
None
:
if
self
.
endpoint
:
values
=
np
.
arange
(
self
.
start
,
self
.
stop
+
1e-7
,
step
=
self
.
step
)
else
:
values
=
np
.
arange
(
self
.
start
,
self
.
stop
,
step
=
self
.
step
)
closest_index
=
np
.
abs
(
values
-
value
).
argmin
()
value
=
values
[
closest_index
]
return
value
def
get_state
(
self
):
state
=
super
(
FloatRange
,
self
).
get_state
()
state
[
"start"
]
=
self
.
start
state
[
"stop"
]
=
self
.
stop
state
[
"step"
]
=
self
.
step
state
[
"endpoint"
]
=
self
.
endpoint
return
state
def
__repr__
(
self
):
return
"FloatRange(name: {}, start: {}, stop: {}, step: {}, default: {}, endpoint: {})"
.
format
(
self
.
name
,
self
.
start
,
self
.
stop
,
self
.
step
,
self
.
default
,
self
.
endpoint
)
python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_space.py
0 → 100644
浏览文件 @
f84b54eb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
paddle.distributed.auto_parallel.tuner
import
tunable_space
as
ts
class
TestTunableSpace
(
unittest
.
TestCase
):
def
test_fixed
(
self
):
space
=
ts
.
TunableSpace
()
fixed
=
space
.
fixed
(
"fixed"
,
default
=
4
)
self
.
assertEqual
(
space
.
values
[
"fixed"
],
4
)
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"fixed"
].
name
,
"fixed"
)
space
.
values
[
"fixed"
]
=
2
self
.
assertEqual
(
space
.
get_value
(
"fixed"
),
2
)
self
.
assertEqual
(
space
.
values
,
{
"fixed"
:
2
})
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"fixed"
].
name
,
"fixed"
)
def
test_boolean
(
self
):
space
=
ts
.
TunableSpace
()
boolean
=
space
.
boolean
(
"boolean"
)
self
.
assertEqual
(
space
.
values
[
"boolean"
],
False
)
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"boolean"
].
name
,
"boolean"
)
space
.
values
[
"boolean"
]
=
True
self
.
assertEqual
(
space
.
get_value
(
"boolean"
),
True
)
self
.
assertEqual
(
space
.
values
,
{
"boolean"
:
True
})
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"boolean"
].
name
,
"boolean"
)
def
test_choice
(
self
):
space
=
ts
.
TunableSpace
()
choice
=
space
.
choice
(
"choice"
,
[
1
,
2
,
3
,
4
],
default
=
4
)
self
.
assertEqual
(
space
.
values
[
"choice"
],
4
)
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"choice"
].
name
,
"choice"
)
space
.
values
[
"choice"
]
=
2
self
.
assertEqual
(
space
.
get_value
(
"choice"
),
2
)
self
.
assertEqual
(
space
.
values
,
{
"choice"
:
2
})
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"choice"
].
name
,
"choice"
)
def
test_int_range
(
self
):
space
=
ts
.
TunableSpace
()
int_range
=
space
.
int_range
(
"int_range"
,
start
=
1
,
stop
=
4
,
default
=
2
)
self
.
assertEqual
(
space
.
values
[
"int_range"
],
2
)
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"int_range"
].
name
,
"int_range"
)
space
.
values
[
"int_range"
]
=
3
self
.
assertEqual
(
space
.
get_value
(
"int_range"
),
3
)
self
.
assertEqual
(
space
.
values
,
{
"int_range"
:
3
})
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"int_range"
].
name
,
"int_range"
)
def
test_float_range
(
self
):
space
=
ts
.
TunableSpace
()
float_range
=
space
.
float_range
(
"float_range"
,
start
=
0.4
,
stop
=
4.4
,
default
=
2.0
)
self
.
assertEqual
(
space
.
values
[
"float_range"
],
2.0
)
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"float_range"
].
name
,
"float_range"
)
space
.
values
[
"float_range"
]
=
3.0
self
.
assertEqual
(
space
.
get_value
(
"float_range"
),
3.0
)
self
.
assertEqual
(
space
.
values
,
{
"float_range"
:
3.0
})
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"float_range"
].
name
,
"float_range"
)
def
test_varibles
(
self
):
space
=
ts
.
TunableSpace
()
choice
=
space
.
choice
(
"choice"
,
[
1
,
2
,
3
,
4
],
default
=
4
)
self
.
assertEqual
(
space
.
values
[
"choice"
],
4
)
self
.
assertEqual
(
len
(
space
.
variables
),
1
)
self
.
assertEqual
(
space
.
variables
[
"choice"
].
name
,
"choice"
)
int_range
=
space
.
int_range
(
"int_range"
,
start
=
1
,
stop
=
4
,
default
=
2
)
self
.
assertEqual
(
space
.
values
[
"int_range"
],
2
)
self
.
assertEqual
(
len
(
space
.
variables
),
2
)
self
.
assertEqual
(
space
.
variables
[
"int_range"
].
name
,
"int_range"
)
def
test_not_populated_variable
(
self
):
space
=
ts
.
TunableSpace
()
choice
=
space
.
choice
(
"choice"
,
[
1
,
2
,
3
,
4
],
default
=
2
)
self
.
assertEqual
(
choice
,
2
)
def
test_populated_variable
(
self
):
space
=
ts
.
TunableSpace
()
space
.
values
[
"choice"
]
=
2
choice
=
space
.
choice
(
"choice"
,
[
1
,
2
,
3
,
4
],
default
=
4
)
self
.
assertEqual
(
choice
,
2
)
space
[
"choice"
]
=
3
self
.
assertNotEqual
(
space
.
values
[
"choice"
],
2
)
self
.
assertEqual
(
space
.
values
[
"choice"
],
3
)
def
test_state
(
self
):
space
=
ts
.
TunableSpace
()
choice
=
space
.
choice
(
"choice"
,
[
1
,
2
,
3
,
4
],
default
=
4
)
int_range
=
space
.
int_range
(
"int_range"
,
start
=
1
,
stop
=
4
,
default
=
2
)
new_space
=
space
.
from_state
(
space
.
get_state
())
self
.
assertEqual
(
new_space
.
get_value
(
"choice"
),
4
)
self
.
assertEqual
(
new_space
.
get_value
(
"int_range"
),
2
)
self
.
assertEqual
(
len
(
new_space
.
variables
),
2
)
self
.
assertEqual
(
len
(
new_space
.
values
),
2
)
self
.
assertEqual
(
new_space
.
variables
[
"choice"
].
name
,
"choice"
)
self
.
assertEqual
(
new_space
.
variables
[
"choice"
].
default
,
4
)
self
.
assertEqual
(
new_space
.
variables
[
"choice"
].
values
,
[
1
,
2
,
3
,
4
])
self
.
assertEqual
(
new_space
.
variables
[
"int_range"
].
name
,
"int_range"
)
self
.
assertEqual
(
new_space
.
variables
[
"int_range"
].
default
,
2
)
self
.
assertEqual
(
new_space
.
variables
[
"int_range"
].
start
,
1
)
self
.
assertEqual
(
new_space
.
variables
[
"int_range"
].
stop
,
4
)
self
.
assertEqual
(
new_space
.
variables
[
"int_range"
].
step
,
1
)
self
.
assertEqual
(
new_space
.
variables
[
"int_range"
].
endpoint
,
False
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/auto_parallel/test_tunable_variable.py
0 → 100644
浏览文件 @
f84b54eb
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
paddle.distributed.auto_parallel.tuner
import
tunable_variable
as
tv
class
TestTunableVariable
(
unittest
.
TestCase
):
def
test_fixed
(
self
):
fixed
=
tv
.
Fixed
(
"fixed"
,
True
)
fixed
=
tv
.
Fixed
.
from_state
(
fixed
.
get_state
())
self
.
assertEqual
(
fixed
.
default
,
True
)
self
.
assertEqual
(
fixed
.
random
(),
True
)
fixed
=
tv
.
Fixed
(
"fixed"
,
1
)
fixed
=
tv
.
Fixed
.
from_state
(
fixed
.
get_state
())
self
.
assertEqual
(
fixed
.
default
,
1
)
self
.
assertEqual
(
fixed
.
random
(),
1
)
def
test_boolean
(
self
):
boolean
=
tv
.
Boolean
(
"bool"
)
boolean
=
tv
.
Boolean
.
from_state
(
boolean
.
get_state
())
self
.
assertEqual
(
boolean
.
default
,
False
)
self
.
assertIn
(
boolean
.
random
(),
[
True
,
False
])
self
.
assertIn
(
boolean
.
random
(
1234
),
[
True
,
False
])
boolean
=
tv
.
Boolean
(
"bool"
,
True
)
boolean
=
tv
.
Boolean
.
from_state
(
boolean
.
get_state
())
self
.
assertEqual
(
boolean
.
default
,
True
)
self
.
assertIn
(
boolean
.
random
(),
[
True
,
False
])
self
.
assertIn
(
boolean
.
random
(
1234
),
[
True
,
False
])
def
test_choice
(
self
):
choice
=
tv
.
Choice
(
"choice"
,
[
1
,
2
,
3
,
4
])
choice
=
tv
.
Choice
.
from_state
(
choice
.
get_state
())
self
.
assertEqual
(
choice
.
default
,
1
)
self
.
assertIn
(
choice
.
random
(),
[
1
,
2
,
3
,
4
])
self
.
assertIn
(
choice
.
random
(
1234
),
[
1
,
2
,
3
,
4
])
choice
=
tv
.
Choice
(
"choice"
,
[
1
,
2
,
3
,
4
],
default
=
2
)
choice
=
tv
.
Choice
.
from_state
(
choice
.
get_state
())
self
.
assertEqual
(
choice
.
default
,
2
)
self
.
assertIn
(
choice
.
random
(),
[
1
,
2
,
3
,
4
])
self
.
assertIn
(
choice
.
random
(
1234
),
[
1
,
2
,
3
,
4
])
def
test_int_range
(
self
):
int_range
=
tv
.
IntRange
(
"int_range"
,
start
=
1
,
stop
=
4
,
default
=
2
)
int_range
=
tv
.
IntRange
.
from_state
(
int_range
.
get_state
())
self
.
assertEqual
(
int_range
.
default
,
2
)
self
.
assertIn
(
int_range
.
random
(),
[
1
,
2
,
3
,
4
])
self
.
assertIn
(
int_range
.
random
(
1234
),
[
1
,
2
,
3
,
4
])
self
.
assertNotEqual
(
int_range
.
default
,
4
)
int_range
=
tv
.
IntRange
(
"int_range"
,
start
=
1
,
stop
=
8
,
step
=
2
,
default
=
3
,
endpoint
=
True
)
int_range
=
tv
.
IntRange
.
from_state
(
int_range
.
get_state
())
self
.
assertEqual
(
int_range
.
default
,
3
)
self
.
assertIn
(
int_range
.
random
(),
[
1
,
3
,
5
,
7
])
self
.
assertIn
(
int_range
.
random
(
1234
),
[
1
,
3
,
5
,
7
])
self
.
assertNotEqual
(
int_range
.
default
,
2
)
def
test_float_range
(
self
):
float_range
=
tv
.
FloatRange
(
"float_range"
,
start
=
0.4
,
stop
=
4.4
,
default
=
2.0
)
float_range
=
tv
.
FloatRange
.
from_state
(
float_range
.
get_state
())
self
.
assertEqual
(
float_range
.
default
,
2.0
)
self
.
assertGreater
(
float_range
.
random
(),
0.4
)
self
.
assertLess
(
float_range
.
random
(
1234
),
4.4
)
self
.
assertNotAlmostEqual
(
float_range
.
random
(),
1
)
self
.
assertNotAlmostEqual
(
float_range
.
random
(),
4.4
)
float_range
=
tv
.
FloatRange
(
"float_range"
,
start
=
0.4
,
stop
=
8.4
,
step
=
2.0
,
default
=
3.0
,
endpoint
=
True
)
float_range
=
tv
.
FloatRange
.
from_state
(
float_range
.
get_state
())
self
.
assertEqual
(
float_range
.
default
,
3.0
)
self
.
assertGreater
(
float_range
.
random
(),
0.4
)
self
.
assertLessEqual
(
float_range
.
random
(
1234
),
8.4
)
self
.
assertNotAlmostEqual
(
float_range
.
random
(),
2
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/setup.py.in
浏览文件 @
f84b54eb
...
...
@@ -300,6 +300,7 @@ packages=['paddle',
'paddle.distributed.fleet.meta_parallel.parallel_layers',
'paddle.distributed.auto_parallel',
'paddle.distributed.auto_parallel.operators',
'paddle.distributed.auto_parallel.tuner',
'paddle.distributed.passes',
'paddle.framework',
'paddle.jit',
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录