Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3206fa80
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3206fa80
编写于
4月 19, 2023
作者:
R
ronnywang
提交者:
GitHub
4月 19, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[CustomDevice] add recompute support (#53044)
* [CustomDevice] add recompute support * update
上级
7e19d16f
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
86 addition
and
10 deletion
+86
-10
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+3
-0
paddle/fluid/pybind/generator_py.cc
paddle/fluid/pybind/generator_py.cc
+1
-0
paddle/phi/core/generator.cc
paddle/phi/core/generator.cc
+11
-0
paddle/phi/core/generator.h
paddle/phi/core/generator.h
+5
-0
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
...optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
+6
-1
python/paddle/distributed/fleet/recompute/recompute.py
python/paddle/distributed/fleet/recompute/recompute.py
+8
-2
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
...on/paddle/distributed/fleet/utils/hybrid_parallel_util.py
+10
-5
python/paddle/framework/random.py
python/paddle/framework/random.py
+42
-2
未找到文件。
paddle/fluid/platform/device_context.cc
浏览文件 @
3206fa80
...
...
@@ -109,6 +109,9 @@ inline std::unique_ptr<DeviceContext> CreateDeviceContext(
dev_ctx
->
SetAllocator
(
instance
.
GetAllocator
(
p
).
get
());
dev_ctx
->
SetGenerator
(
phi
::
DefaultXPUGenerator
(
p
.
GetDeviceId
()).
get
());
#endif
}
else
if
(
p
.
GetType
()
==
phi
::
AllocationType
::
CUSTOM
)
{
dev_ctx
->
SetAllocator
(
instance
.
GetAllocator
(
p
).
get
());
dev_ctx
->
SetGenerator
(
phi
::
DefaultCustomDeviceGenerator
(
p
).
get
());
}
else
{
dev_ctx
->
SetAllocator
(
instance
.
GetAllocator
(
p
).
get
());
dev_ctx
->
SetGenerator
(
phi
::
DefaultCPUGenerator
().
get
());
...
...
paddle/fluid/pybind/generator_py.cc
浏览文件 @
3206fa80
...
...
@@ -88,6 +88,7 @@ void BindGenerator(py::module* m_ptr) {
m
.
def
(
"default_cpu_generator"
,
&
phi
::
DefaultCPUGenerator
);
m
.
def
(
"default_cuda_generator"
,
&
phi
::
DefaultCUDAGenerator
);
m
.
def
(
"default_xpu_generator"
,
&
phi
::
DefaultXPUGenerator
);
m
.
def
(
"default_custom_device_generator"
,
&
phi
::
DefaultCustomDeviceGenerator
);
m
.
def
(
"set_random_seed_generator"
,
&
phi
::
SetRandomSeedGenerator
);
m
.
def
(
"get_random_seed_generator"
,
&
phi
::
GetRandomSeedGenerator
);
}
...
...
paddle/phi/core/generator.cc
浏览文件 @
3206fa80
...
...
@@ -99,6 +99,17 @@ const std::shared_ptr<Generator>& DefaultCPUGenerator() {
return
default_cpu_generator
;
}
const
std
::
shared_ptr
<
Generator
>&
DefaultCustomDeviceGenerator
(
const
phi
::
CustomPlace
&
place
)
{
static
std
::
unordered_map
<
phi
::
Place
,
std
::
shared_ptr
<
Generator
>
,
phi
::
Place
::
Hash
>
generators
;
if
(
generators
.
find
(
place
)
==
generators
.
end
())
{
generators
.
insert
({
place
,
std
::
make_shared
<
Generator
>
(
GetRandomSeed
())});
}
return
generators
[
place
];
}
using
RNGMap
=
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
Generator
>>
;
static
RNGMap
&
GetRandomSeedGeneratorMap
()
{
...
...
paddle/phi/core/generator.h
浏览文件 @
3206fa80
...
...
@@ -25,6 +25,8 @@ limitations under the License. */
#include <typeinfo>
#include <utility>
#include "paddle/phi/common/place.h"
namespace
phi
{
class
Generator
{
...
...
@@ -80,6 +82,9 @@ const std::shared_ptr<Generator>& DefaultCUDAGenerator(int64_t device_id = -1);
const
std
::
shared_ptr
<
Generator
>&
DefaultXPUGenerator
(
int64_t
device_id
=
-
1
);
const
std
::
shared_ptr
<
Generator
>&
DefaultCustomDeviceGenerator
(
const
phi
::
CustomPlace
&
place
);
std
::
shared_ptr
<
std
::
mt19937_64
>
GetCPURandomEngine
(
uint64_t
);
const
std
::
shared_ptr
<
Generator
>&
SetRandomSeedGenerator
(
...
...
python/paddle/distributed/fleet/meta_optimizers/dygraph_optimizer/hybrid_parallel_optimizer.py
浏览文件 @
3206fa80
...
...
@@ -205,7 +205,12 @@ class HybridParallelClipGrad:
clip_var_fp16
=
paddle
.
cast
(
clip_var
,
paddle
.
float16
)
# bf16 is not supported on XPU now
if
not
paddle
.
is_compiled_with_xpu
():
if
not
(
paddle
.
is_compiled_with_xpu
()
or
isinstance
(
paddle
.
framework
.
_current_expected_place
(),
paddle
.
CustomPlace
)
):
clip_var_bf16
=
paddle
.
cast
(
clip_var
,
paddle
.
bfloat16
)
for
p
,
g
in
params_grads
:
if
g
is
None
:
...
...
python/paddle/distributed/fleet/recompute/recompute.py
浏览文件 @
3206fa80
...
...
@@ -222,13 +222,19 @@ def _recompute_without_reentrant(
if
preserve_rng_state
:
cur_device
=
paddle
.
get_device
()
if
'gpu:'
not
in
cur_device
:
if
'gpu:'
in
cur_device
:
fw_cuda_rng_state
=
paddle
.
get_cuda_rng_state
()
elif
(
cur_device
.
split
(
':'
)[
0
]
in
paddle
.
device
.
get_all_custom_device_type
()
):
fw_cuda_rng_state
=
paddle
.
get_rng_state
(
cur_device
)
else
:
raise
RuntimeError
(
"Recompute with RNG perserve is not support current device: {}."
.
format
(
cur_device
)
)
fw_cuda_rng_state
=
paddle
.
get_cuda_rng_state
()
fwd_cuda_rng_state_tracker
=
(
get_rng_state_tracker
().
get_states_tracker
()
)
...
...
python/paddle/distributed/fleet/utils/hybrid_parallel_util.py
浏览文件 @
3206fa80
...
...
@@ -154,14 +154,19 @@ def _broadcast_object_list_help(object_list, hcg):
def
broadcast_input_data
(
hcg
,
*
inputs
,
**
kwargs
):
cur_device
=
paddle
.
get_device
()
dev
=
cur_device
.
split
(
":"
)[
0
]
assert
dev
in
[
"xpu"
,
"gpu"
,
"npu"
,
],
f
"Only support xpu, gpu and npu now, but this is
{
dev
}
"
assert
(
dev
in
[
"xpu"
,
"gpu"
,
]
or
dev
in
paddle
.
device
.
get_all_custom_device_type
()
),
f
"Only support xpu, gpu and custom_device now, but this is
{
dev
}
"
dev_idx
=
int
(
cur_device
.
split
(
':'
)[
1
])
if
dev
==
"gpu"
:
place
=
paddle
.
CUDAPlace
(
dev_idx
)
elif
dev
in
paddle
.
device
.
get_all_custom_device_type
():
place
=
paddle
.
CustomPlace
(
dev
,
dev_idx
)
else
:
place
=
eval
(
f
"paddle.
{
dev
.
upper
()
}
Place"
)(
dev_idx
)
...
...
python/paddle/framework/random.py
浏览文件 @
3206fa80
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
# TODO: define random api
import
paddle
from
paddle
import
fluid
from
paddle.fluid
import
core
...
...
@@ -48,7 +49,18 @@ def seed(seed):
elif
core
.
is_compiled_with_xpu
():
for
i
in
range
(
core
.
get_xpu_device_count
()):
core
.
default_xpu_generator
(
i
).
manual_seed
(
seed
)
place
=
fluid
.
framework
.
_current_expected_place
()
if
isinstance
(
place
,
core
.
CustomPlace
):
dev_cnt
=
sum
(
[
place
.
get_device_type
()
==
s
.
split
(
':'
)[
0
]
for
s
in
core
.
get_available_custom_device
()
]
)
for
i
in
range
(
dev_cnt
):
core
.
default_custom_device_generator
(
core
.
CustomPlace
(
place
.
get_device_type
(),
i
)
).
manual_seed
(
seed
)
return
core
.
default_cpu_generator
().
manual_seed
(
seed
)
...
...
@@ -70,7 +82,7 @@ def get_rng_state(device=None):
if
device
is
None
:
place
=
fluid
.
framework
.
_current_expected_place
()
else
:
place
=
device
.
_convert_to_place
(
device
)
place
=
paddle
.
device
.
_convert_to_place
(
device
)
if
isinstance
(
place
,
core
.
CPUPlace
):
state_list
.
append
(
core
.
default_cpu_generator
().
get_state
())
...
...
@@ -80,6 +92,19 @@ def get_rng_state(device=None):
elif
isinstance
(
place
,
core
.
XPUPlace
):
for
i
in
range
(
core
.
get_xpu_device_count
()):
state_list
.
append
(
core
.
default_xpu_generator
(
i
).
get_state
())
elif
isinstance
(
place
,
core
.
CustomPlace
):
dev_cnt
=
sum
(
[
place
.
get_device_type
()
==
s
.
split
(
':'
)[
0
]
for
s
in
core
.
get_available_custom_device
()
]
)
for
i
in
range
(
dev_cnt
):
state_list
.
append
(
core
.
default_custom_device_generator
(
core
.
CustomPlace
(
place
.
get_device_type
(),
i
)
).
get_state
()
)
else
:
raise
ValueError
(
"get_rng_state is not implemented for current device: {}"
.
format
(
...
...
@@ -157,6 +182,21 @@ def set_rng_state(state_list, device=None):
)
for
i
in
range
(
core
.
get_xpu_device_count
()):
core
.
default_xpu_generator
(
i
).
set_state
(
state_list
[
i
])
elif
isinstance
(
place
,
core
.
CustomPlace
):
dev_cnt
=
sum
(
[
place
.
get_device_type
()
==
s
.
split
(
':'
)[
0
]
for
s
in
core
.
get_available_custom_device
()
]
)
if
not
len
(
state_list
)
==
dev_cnt
:
raise
ValueError
(
f
"Length of custom device state list shoule be equal to the
{
place
.
get_dtype_type
()
}
device count"
)
for
i
in
range
(
dev_cnt
):
core
.
default_custom_device_generator
(
core
.
CustomPlace
(
place
.
get_device_type
(),
i
)
).
set_state
(
state_list
[
i
])
elif
isinstance
(
place
,
core
.
CPUPlace
):
if
not
len
(
state_list
)
==
1
:
raise
ValueError
(
"Length of cpu state list shoule be equal to 1"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录