BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Unverified commit b4a3dab7
Authored by Yuang Liu on Jun 07, 2022; committed via GitHub on Jun 07, 2022
[cuda graph] Add cuda graph attr to op desc (#43228)
Parent commit: 2922985a
Showing 4 changed files with 237 additions and 26 deletions (+237 -26):

    python/paddle/device/cuda/graphs.py                                           +20   -0
    python/paddle/fluid/backward.py                                               +111  -26
    python/paddle/fluid/framework.py                                              +35   -0
    python/paddle/fluid/tests/unittests/test_cuda_graph_partial_graph_static.py  +71   -0
python/paddle/device/cuda/graphs.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
+import paddle
 from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace
 
 if is_compiled_with_cuda() and not is_compiled_with_rocm():
@@ -28,6 +29,7 @@ else:
 
 ALL_MODES = ["global", "thread_local", "relaxed"]
+cuda_graph_id = 0
 
 
 class CUDAGraph:
@@ -68,6 +70,24 @@ class CUDAGraph:
 
 def wrap_cuda_graph(function, mode="thread_local", memory_pool="default"):
     assert mode in ALL_MODES
+    if not paddle.in_dynamic_mode():
+        # static mode
+        from paddle.fluid.framework import _cuda_graph_guard
+        global cuda_graph_id
+        graph_id = str(cuda_graph_id)
+        cuda_graph_id += 1
+        if memory_pool == 'default':
+            memory_pool_id = 0
+        elif memory_pool == 'new':
+            memory_pool_id = CoreCUDAGraph.gen_new_memory_pool_id()
+        else:
+            raise ValueError(
+                "memory_pool should be one of default or new under static mode, but got",
+                memory_pool)
+        return _cuda_graph_guard(
+            mode + ';' + str(memory_pool_id) + ';' +
+            graph_id)(lambda *args, **kwargs: function(*args, **kwargs))
     from paddle.jit import to_static
     from paddle.nn import Layer
     new_function = to_static(function)
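For orientation, a minimal usage sketch of the new static-mode branch, modeled on the unit test added in this commit. The layer, shapes, and variable names below are illustrative only, and the sketch assumes a CUDA build of Paddle where CUDA graph capture is supported.

# Minimal sketch (assumes a CUDA build; layer and shapes are illustrative, not from the commit).
import paddle
import paddle.nn as nn
from paddle.device.cuda.graphs import wrap_cuda_graph, is_cuda_graph_supported

paddle.enable_static()

if is_cuda_graph_supported():
    main_prog = paddle.static.Program()
    start_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, start_prog):
        model = nn.Linear(10, 20)              # illustrative layer
        graph_model = wrap_cuda_graph(model)   # static branch wraps the call in _cuda_graph_guard
        x = paddle.static.data(name='x', shape=[3, 10], dtype='float32')
        y = graph_model(x)                     # ops created here carry _cuda_graph_attr

    # Ops built inside the wrapped call are tagged, e.g. 'thread_local;0;0';
    # ops built outside the guard (such as the data op) keep _cuda_graph_attr == None.
    for op in main_prog.global_block().ops:
        print(op.type, op._cuda_graph_attr)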
python/paddle/fluid/backward.py

@@ -236,7 +236,11 @@ def _pretty_op_desc_(op_desc, prefix):
     return out_s
 
 
-def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
+def _add_needed_descs_to_block(descs,
+                               block,
+                               main_block,
+                               in_memory_vars,
+                               grad_op_id_to_fwd_op=None):
     if len(descs) == 0:
         return []
     result_descs = []
@@ -244,8 +248,11 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
         core.op_proto_and_checker_maker.kOpRoleAttrName()
     backward = core.op_proto_and_checker_maker.OpRole.Backward
     for desc in descs:
+        origin_desc = desc
+        origin_is_operator = False
         if isinstance(desc, framework.Operator):
             desc = desc.desc
+            origin_is_operator = True
         if isinstance(desc, tuple):
             desc = desc[0]
         is_needed = False
@@ -255,6 +262,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
             if name not in in_memory_vars:
                 is_needed = True
         if is_needed:
+            if origin_is_operator and grad_op_id_to_fwd_op is not None:
+                grad_op_id_to_fwd_op[desc.original_id()] = origin_desc
             new_op_desc = block.desc.append_op()
             new_op_desc.copy_from(desc)
             new_op_desc._set_attr(op_role_attr_name, backward)
@@ -264,7 +273,7 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
     return result_descs
 
 
-def _add_descs_to_block(descs, block):
+def _add_descs_to_block(descs, block, grad_op_id_to_fwd_op=None):
     if len(descs) == 0:
         return []
     result_descs = []
@@ -273,6 +282,9 @@ def _add_descs_to_block(descs, block):
     backward = core.op_proto_and_checker_maker.OpRole.Backward
     for desc in descs:
         if isinstance(desc, framework.Operator):
+            # for recompute, should record recompute ops
+            if grad_op_id_to_fwd_op is not None:
+                grad_op_id_to_fwd_op[desc.desc.original_id()] = desc
             desc = desc.desc
         if isinstance(desc, tuple):
             desc = desc[0]
@@ -489,7 +501,10 @@ def _accumulate_gradients_by_add_ops_(var_name,
         renamed_vars[var_name] = [var_name]
 
 
-def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
+def _addup_repetitive_outputs_(op_descs,
+                               block_idx,
+                               grad_var_to_var=None,
+                               grad_op_id_to_fwd_op=None):
     """
     In backward part, an variable may be the output of more than one ops.
     And one op may yield its multiple outputs to the same variable.
@@ -500,6 +515,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
         grad_var_to_var(dict): used to build the mapping between grad var name and forward var name.
         Only for auto parallel.
     """
     _MAX_ADD_NUM_ = framework._global_flags()['FLAGS_max_inplace_grad_add']
     #pending_sum_ops = []
     pending_sum_ops = collections.OrderedDict()
@@ -604,6 +620,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
                                   len(op_descs), var_device[var_name])
+    op_descs_len = len(op_descs)
     # sum_op descs are sorted according to their insert position
     for key, value in collections.OrderedDict(
             reversed(list(pending_sum_ops.items()))).items():
@@ -614,12 +631,18 @@ def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
         # If not reverse, we first insert 'a' at idx 1, it becomes [0, 1, 'a', 2], and then insert 'b' at idx 2, it becomes [0, 1, 'a', 'b', 2].
         idx = key
         for i, op in enumerate(value):
+            # update the mapping between fwd and bwd
+            target_idx = idx - 1 if idx == op_descs_len else idx + i
+            if grad_op_id_to_fwd_op is not None and grad_op_id_to_fwd_op.get(
+                    op_descs[target_idx].original_id(), None) is not None:
+                grad_op_id_to_fwd_op[op.original_id()] = grad_op_id_to_fwd_op[
+                    op_descs[target_idx].original_id()]
             op_descs.insert(idx + i, op)
 
     return op_descs
 
 
-def _remove_no_grad_branch_(op_descs, no_grad_set):
+def _remove_no_grad_branch_(op_descs, no_grad_set, grad_op_id_to_fwd_op=None):
     """
     Remove unnecessary grad ops
     A grad op can be removed in two cases:
@@ -653,9 +676,14 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
                 x_in = _strip_grad_suffix_(arg)
                 # the reason should be: arg can be input of another grad op
                 # and the op is a not-to-remove op
-                to_insert.append((_create_op_desc_("fill_zeros_like",
-                                                   {"X": [x_in]},
-                                                   {"Out": [arg]}, {}), idx))
+                new_op_desc = _create_op_desc_("fill_zeros_like", {"X": [x_in]},
+                                               {"Out": [arg]}, {})
+                # update the mapping between fwd and bwd
+                if grad_op_id_to_fwd_op is not None and grad_op_id_to_fwd_op.get(
+                        op_desc.original_id(), None) is not None:
+                    grad_op_id_to_fwd_op[new_op_desc.original_id(
+                    )] = grad_op_id_to_fwd_op[op_desc.original_id()]
+                to_insert.append((new_op_desc, idx))
 
     list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)])
 
@@ -794,9 +822,13 @@ def serialize_op_decs(op_desc):
     return proto.__str__()
 
 
-def _append_backward_ops_with_checkpoints_(block, ops, target_block,
-                                           no_grad_dict, grad_to_var,
-                                           checkpoints):
+def _append_backward_ops_with_checkpoints_(block,
+                                           ops,
+                                           target_block,
+                                           no_grad_dict,
+                                           grad_to_var,
+                                           checkpoints,
+                                           grad_op_id_to_fwd_op=None):
     """
     Create grad ops with forward ops, and insert them into given block
@@ -926,12 +958,19 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
                     _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+
+            # record the mapping between fwd and bwd
+            if grad_op_id_to_fwd_op is not None:
+                for op_desc in grad_op_desc:
+                    grad_op_id_to_fwd_op[op_desc.original_id()] = op
+
             # Set device for grad_op according to forward Op
             if op.desc.has_attr(device_attr_name):
                 op_device = op.desc.attr(device_attr_name)
                 for op_desc in grad_op_desc:
                     op_desc._set_attr(device_attr_name, op_device)
-            added_descs = _add_descs_to_block(grad_op_desc, local_block)
+            added_descs = _add_descs_to_block(grad_op_desc, local_block,
+                                              grad_op_id_to_fwd_op)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
@@ -945,12 +984,19 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
                     _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+
+            # record the mapping between fwd and bwd
+            if grad_op_id_to_fwd_op is not None:
+                for op_desc in grad_op_desc:
+                    grad_op_id_to_fwd_op[op_desc.original_id()] = op
+
             # Set device for grad_op according to forward Op
             if op.desc.has_attr(device_attr_name):
                 op_device = op.desc.attr(device_attr_name)
                 for op_desc in grad_op_desc:
                     op_desc._set_attr(device_attr_name, op_device)
-            added_descs = _add_descs_to_block(grad_op_desc, local_block)
+            added_descs = _add_descs_to_block(grad_op_desc, local_block,
+                                              grad_op_id_to_fwd_op)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
@@ -984,8 +1030,10 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
             # 3.a. add ops in current recompute_segment as forward recomputation ops
             buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block,
-                                                      block, vars_in_memory)
-            added_descs = _add_descs_to_block(ff_ops, local_block)
+                                                      block, vars_in_memory,
+                                                      grad_op_id_to_fwd_op)
+            added_descs = _add_descs_to_block(ff_ops, local_block,
+                                              grad_op_id_to_fwd_op)
 
             # 3.b. rename all non-checkpoint variables in recomputation ops
             for key in var_name_dict:
@@ -999,6 +1047,12 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
                 grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                     op_desc, cpt.to_text(no_grad_dict[block.idx]), [])
+
+                # record the mapping between fwd and bwd
+                if grad_op_id_to_fwd_op is not None:
+                    for g_op_desc in grad_op_desc:
+                        grad_op_id_to_fwd_op[g_op_desc.original_id(
+                        )] = grad_op_id_to_fwd_op[op_desc.original_id()]
+
                 # Set device for grad_op according to forward Op
                 if op_desc.has_attr(device_attr_name):
                     op_device = op_desc.attr(device_attr_name)
@@ -1011,11 +1065,14 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
                 grad_to_var.update(op_grad_to_var)
 
         # 3.d. add sum op for repetitive_outputs
-        grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx)
+        grad_op_descs = _addup_repetitive_outputs_(
+            grad_op_descs, block.idx, grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
         # 4) remove no grad branch as it is in _remove_no_grad_branch_
-        grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
-                                                no_grad_dict[block.idx])
-        added_descs = _add_descs_to_block(grad_op_descs, target_block)
+        grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
+                                                no_grad_dict[block.idx],
+                                                grad_op_id_to_fwd_op)
+        added_descs = _add_descs_to_block(grad_op_descs, target_block,
+                                          grad_op_id_to_fwd_op)
         return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments
@@ -1090,7 +1147,8 @@ def _append_backward_ops_(block,
                           input_grad_names_set=None,
                           op_path_dict=None,
                           distop_context=None,
-                          rename_var_map=None):
+                          rename_var_map=None,
+                          grad_op_id_to_fwd_op=None):
     """
     Create all grad ops, and insert them into given block
@@ -1152,9 +1210,15 @@ def _append_backward_ops_(block,
                 pre_input_grad_names_set = copy.copy(input_grad_names_set)
                 input_grad_names_set = None
                 sub_block_path = op_path_dict[op._block_attr_id("sub_block")]
-                _append_backward_ops_(sub_block, sub_block_path, grad_sub_block,
-                                      no_grad_dict, grad_to_var, callbacks,
-                                      input_grad_names_set, op_path_dict)
+                _append_backward_ops_(sub_block,
+                                      sub_block_path,
+                                      grad_sub_block,
+                                      no_grad_dict,
+                                      grad_to_var,
+                                      callbacks,
+                                      input_grad_names_set,
+                                      op_path_dict,
+                                      grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
                 input_grad_names_set = pre_input_grad_names_set
 
                 program._rollback()
@@ -1164,6 +1228,11 @@ def _append_backward_ops_(block,
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
+
+        # record the mapping between fwd and bwd
+        if grad_op_id_to_fwd_op is not None:
+            for op_desc in grad_op_desc:
+                grad_op_id_to_fwd_op[op_desc.original_id()] = op
+
         # Build the mapping between the forward op and backward op (Only for auto parallel)
         if distop_context is not None:
             update_distop_context(distop_context, op_grad_to_var,
@@ -1251,13 +1320,17 @@ def _append_backward_ops_(block,
         grad_var_to_var = distop_context.grad_var_to_var[
             program._appending_grad_times]
     # sum parameter's gradients' var given multiple var gradient
-    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx,
-                                               grad_var_to_var)
+    grad_op_descs = _addup_repetitive_outputs_(
+        grad_op_descs,
+        block.idx,
+        grad_var_to_var,
+        grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
 
     # if all outputs of the grad op are in no_grad_set, then just remove and fill zero
     # if all inputs of the grad op are in no_grad_set, just remove this op
-    grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
-                                            no_grad_dict[block.idx])
+    grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
+                                            no_grad_dict[block.idx],
+                                            grad_op_id_to_fwd_op)
 
     # remove some backward ops
     not_need_ops = _find_not_need_ops(grad_op_descs, ops, input_grad_names_set)
@@ -1585,6 +1658,9 @@ def append_backward(loss,
            p_g_list6 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights, no_grad_set=set(all_weights))
 
     """
+    grad_op_id_to_fwd_op = {
+    }  # for cuda graph usage, recording the mapping between grad op original id to fwd op
+
     check_type(loss, 'loss', framework.Variable,
                'paddle.static.append_backward')
@@ -1644,7 +1720,9 @@ def append_backward(loss,
     grad_to_var = dict()
 
+    # pass the cuda_graph_attr to the fill_constant which generates the loss_grad
     op_desc = _create_loss_op_desc_(loss)
+    grad_op_id_to_fwd_op[op_desc.original_id()] = loss.op
     target_grad_block.desc.append_op().copy_from(op_desc)
 
     for block_idx in son_parent_block_idx_dict:
@@ -1690,7 +1768,8 @@ def append_backward(loss,
                 root_block,
                 no_grad_dict,
                 grad_to_var,
-                checkpoints)
+                checkpoints,
+                grad_op_id_to_fwd_op)
         else:
             _append_backward_ops_(
                 block,  # the block where forward ops are in
@@ -1702,7 +1781,7 @@ def append_backward(loss,
                 input_grad_names_set=input_grad_names_set,
                 op_path_dict=op_path_dict,
-                distop_context=distop_context,
-            )
+                distop_context=distop_context,
+                grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
 
     grad_info_map = dict()
@@ -1722,6 +1801,12 @@ def append_backward(loss,
     program.current_block_idx = current_block_idx
     program._sync_with_cpp()
 
+    # for cuda graph, copy the cuda graph attr from forward op to backward op
+    for op in target_grad_block.ops:
+        if grad_op_id_to_fwd_op.get(op.desc.original_id(), None) is not None:
+            fwd_op = grad_op_id_to_fwd_op[op.desc.original_id()]
+            op._cuda_graph_attr = fwd_op._cuda_graph_attr
+
     if parameter_list is not None:
         check_type(parameter_list, 'parameter_list', (list, tuple, set),
                    'fluid.backward.append_backward')
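The thread running through all of these backward.py changes is one dictionary, grad_op_id_to_fwd_op, that append_backward passes down to every helper: whenever a grad op desc is created, its original_id() is mapped to the forward op it came from, and after _sync_with_cpp() each backward op copies _cuda_graph_attr from its mapped forward op. A standalone toy sketch of that final copy step (plain Python objects, not the Paddle API; FakeOp is hypothetical):

# Toy illustration of the id-keyed propagation; FakeOp stands in for Paddle ops.
class FakeOp:
    def __init__(self, original_id, cuda_graph_attr=None):
        self._id = original_id
        self._cuda_graph_attr = cuda_graph_attr

    def original_id(self):
        return self._id


fwd = FakeOp(original_id=7, cuda_graph_attr='thread_local;0;0')   # forward op, already tagged
bwd = FakeOp(original_id=42)                                      # its grad op, not yet tagged

# Filled while grad op descs are generated, keyed by the grad op's original id.
grad_op_id_to_fwd_op = {bwd.original_id(): fwd}

# The loop at the end of append_backward does essentially this over target_grad_block.ops:
for op in [bwd]:
    fwd_op = grad_op_id_to_fwd_op.get(op.original_id(), None)
    if fwd_op is not None:
        op._cuda_graph_attr = fwd_op._cuda_graph_attr

assert bwd._cuda_graph_attr == 'thread_local;0;0'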
python/paddle/fluid/framework.py

@@ -81,6 +81,7 @@ global_prog_seed = 0
 _current_pipeline_stage = None
 _already_patch_eager_tensor = False
 _already_patch_varbase = False
+_current_cuda_graph_mode = None
 _global_flags_ = core.globals()
 
 # Some explanation of our execution system 2022.03
@@ -2622,6 +2623,9 @@ class Operator(object):
                 op_attrs = dict()
             del attrs
 
+            # attr for static mode cuda graph
+            self._cuda_graph_attr = _current_cuda_graph_mode
+
             op_maker = core.op_proto_and_checker_maker
 
             if op_maker.kOpRoleAttrName() not in op_attrs:
@@ -7017,6 +7021,37 @@ def device_guard(device=None):
         switch_device(pre_device)
 
 
+def _switch_cuda_graph_mode(cuda_graph_attr):
+    global _current_cuda_graph_mode
+    pre_mode = _current_cuda_graph_mode
+    _current_cuda_graph_mode = cuda_graph_attr
+    return pre_mode
+
+
+@signature_safe_contextmanager
+def _cuda_graph_guard(cuda_graph_attr=None):
+    """
+    Note:
+        The API only supports static mode.
+
+    A context manager that specifies the cuda_graph_mode which indicating the cuda graph capture under static mode.
+
+    Args:
+        cuda_graph_attr(str|None): The cuda graph attr with the format of:
+                                   cuda_graph_capture_mode;memory_pool_id;cuda_graph_id
+    """
+    assert not _non_static_mode(
+    ), "cuda_graph_guard only works under static mode"
+    assert core.is_compiled_with_cuda(
+    ), "cuda_graph_guard context can be only used when Paddle is compiled with cuda"
+    pre_mode = _switch_cuda_graph_mode(cuda_graph_attr)
+    try:
+        yield
+    finally:
+        _switch_cuda_graph_mode(pre_mode)
+
+
 def set_flags(flags):
     """
     This function sets the GFlags value in Paddle.
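_cuda_graph_guard simply swaps a module-level global, _current_cuda_graph_mode, and every Operator constructed while the guard is active records that string as its _cuda_graph_attr. A small sketch of entering the guard directly follows; it is an internal API normally reached only through wrap_cuda_graph, and the sketch assumes a CUDA build of Paddle (the guard asserts core.is_compiled_with_cuda()). The attr string and op choice here are illustrative.

# Sketch only: _cuda_graph_guard is internal; normal use goes through wrap_cuda_graph.
import paddle
from paddle.fluid.framework import _cuda_graph_guard

paddle.enable_static()

prog = paddle.static.Program()
with paddle.static.program_guard(prog):
    x = paddle.static.data(name='x', shape=[3, 10], dtype='float32')
    # attr format: cuda_graph_capture_mode;memory_pool_id;cuda_graph_id
    with _cuda_graph_guard('relaxed;0;1'):
        y = paddle.nn.functional.relu(x)   # this op picks up _cuda_graph_attr == 'relaxed;0;1'

for op in prog.global_block().ops:
    print(op.type, op._cuda_graph_attr)    # the data op created outside the guard stays None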
python/paddle/fluid/tests/unittests/test_cuda_graph_partial_graph_static.py (new file, 0 → 100644)

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import unittest
import numpy as np
from paddle.device.cuda.graphs import wrap_cuda_graph, is_cuda_graph_supported

paddle.enable_static()


class SimpleModel(nn.Layer):

    def __init__(self, in_size, out_size):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(in_size, out_size)
        self.dropout_1 = paddle.nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.dropout_2 = paddle.nn.Dropout(0.5)
        self.gelu = nn.GELU()

    def forward(self, x):
        x = self.linear(x)
        x = self.dropout_1(x)
        x = self.relu(x)
        x = self.dropout_2(x)
        x = self.gelu(x)
        return x


class TestCudaGraphAttrAll(unittest.TestCase):

    def test_all_program(self):
        if not is_cuda_graph_supported():
            return
        main_prog = paddle.static.Program()
        start_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, start_prog):
            model = SimpleModel(10, 20)
            cuda_graph_model = wrap_cuda_graph(model)
            x = paddle.static.data(shape=[3, 10], dtype='float32', name='x')
            y = cuda_graph_model(x)
            loss = paddle.mean(y)
            opt = paddle.optimizer.SGD()
            opt.minimize(loss)
            block = main_prog.global_block()
            for op in block.ops:
                if op._cuda_graph_attr is None:
                    # the loss and opt are not wrapped
                    assert op.type in [
                        'sgd', 'reduce_mean', 'fill_constant',
                        'reduce_mean_grad'
                    ]
                else:
                    assert op._cuda_graph_attr == 'thread_local;0;0'


if __name__ == "__main__":
    unittest.main()