Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b4a3dab7
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b4a3dab7
编写于
6月 07, 2022
作者:
Y
Yuang Liu
提交者:
GitHub
6月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[cuda graph] Add cuda graph attr to op desc (#43228)
上级
2922985a
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
237 addition
and
26 deletion
+237
-26
python/paddle/device/cuda/graphs.py
python/paddle/device/cuda/graphs.py
+20
-0
python/paddle/fluid/backward.py
python/paddle/fluid/backward.py
+111
-26
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+35
-0
python/paddle/fluid/tests/unittests/test_cuda_graph_partial_graph_static.py
...d/tests/unittests/test_cuda_graph_partial_graph_static.py
+71
-0
未找到文件。
python/paddle/device/cuda/graphs.py
浏览文件 @
b4a3dab7
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
os
import
paddle
from
paddle.fluid.core
import
is_compiled_with_cuda
,
is_compiled_with_rocm
,
CUDAPlace
if
is_compiled_with_cuda
()
and
not
is_compiled_with_rocm
():
...
...
@@ -28,6 +29,7 @@ else:
ALL_MODES
=
[
"global"
,
"thread_local"
,
"relaxed"
]
cuda_graph_id
=
0
class
CUDAGraph
:
...
...
@@ -68,6 +70,24 @@ class CUDAGraph:
def
wrap_cuda_graph
(
function
,
mode
=
"thread_local"
,
memory_pool
=
"default"
):
assert
mode
in
ALL_MODES
if
not
paddle
.
in_dynamic_mode
():
# static mode
from
paddle.fluid.framework
import
_cuda_graph_guard
global
cuda_graph_id
graph_id
=
str
(
cuda_graph_id
)
cuda_graph_id
+=
1
if
memory_pool
==
'default'
:
memory_pool_id
=
0
elif
memory_pool
==
'new'
:
memory_pool_id
=
CoreCUDAGraph
.
gen_new_memory_pool_id
()
else
:
raise
ValueError
(
"memory_pool should be one of default or new under static mode, but got"
,
memory_pool
)
return
_cuda_graph_guard
(
mode
+
';'
+
str
(
memory_pool_id
)
+
';'
+
graph_id
)(
lambda
*
args
,
**
kwargs
:
function
(
*
args
,
**
kwargs
))
from
paddle.jit
import
to_static
from
paddle.nn
import
Layer
new_function
=
to_static
(
function
)
...
...
python/paddle/fluid/backward.py
浏览文件 @
b4a3dab7
...
...
@@ -236,7 +236,11 @@ def _pretty_op_desc_(op_desc, prefix):
return
out_s
def
_add_needed_descs_to_block
(
descs
,
block
,
main_block
,
in_memory_vars
):
def
_add_needed_descs_to_block
(
descs
,
block
,
main_block
,
in_memory_vars
,
grad_op_id_to_fwd_op
=
None
):
if
len
(
descs
)
==
0
:
return
[]
result_descs
=
[]
...
...
@@ -244,8 +248,11 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
core
.
op_proto_and_checker_maker
.
kOpRoleAttrName
()
backward
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Backward
for
desc
in
descs
:
origin_desc
=
desc
origin_is_operator
=
False
if
isinstance
(
desc
,
framework
.
Operator
):
desc
=
desc
.
desc
origin_is_operator
=
True
if
isinstance
(
desc
,
tuple
):
desc
=
desc
[
0
]
is_needed
=
False
...
...
@@ -255,6 +262,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
if
name
not
in
in_memory_vars
:
is_needed
=
True
if
is_needed
:
if
origin_is_operator
and
grad_op_id_to_fwd_op
is
not
None
:
grad_op_id_to_fwd_op
[
desc
.
original_id
()]
=
origin_desc
new_op_desc
=
block
.
desc
.
append_op
()
new_op_desc
.
copy_from
(
desc
)
new_op_desc
.
_set_attr
(
op_role_attr_name
,
backward
)
...
...
@@ -264,7 +273,7 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
return
result_descs
def
_add_descs_to_block
(
descs
,
block
):
def
_add_descs_to_block
(
descs
,
block
,
grad_op_id_to_fwd_op
=
None
):
if
len
(
descs
)
==
0
:
return
[]
result_descs
=
[]
...
...
@@ -273,6 +282,9 @@ def _add_descs_to_block(descs, block):
backward
=
core
.
op_proto_and_checker_maker
.
OpRole
.
Backward
for
desc
in
descs
:
if
isinstance
(
desc
,
framework
.
Operator
):
# for recompute, should record recompute ops
if
grad_op_id_to_fwd_op
is
not
None
:
grad_op_id_to_fwd_op
[
desc
.
desc
.
original_id
()]
=
desc
desc
=
desc
.
desc
if
isinstance
(
desc
,
tuple
):
desc
=
desc
[
0
]
...
...
@@ -489,7 +501,10 @@ def _accumulate_gradients_by_add_ops_(var_name,
renamed_vars
[
var_name
]
=
[
var_name
]
def
_addup_repetitive_outputs_
(
op_descs
,
block_idx
,
grad_var_to_var
=
None
):
def
_addup_repetitive_outputs_
(
op_descs
,
block_idx
,
grad_var_to_var
=
None
,
grad_op_id_to_fwd_op
=
None
):
"""
In backward part, an variable may be the output of more than one ops.
And one op may yield its multiple outputs to the same variable.
...
...
@@ -500,6 +515,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
grad_var_to_var(dict): used to build the mapping between grad var name and forward var name.
Only for auto parallel.
"""
_MAX_ADD_NUM_
=
framework
.
_global_flags
()[
'FLAGS_max_inplace_grad_add'
]
#pending_sum_ops = []
pending_sum_ops
=
collections
.
OrderedDict
()
...
...
@@ -604,6 +620,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
len
(
op_descs
),
var_device
[
var_name
])
op_descs_len
=
len
(
op_descs
)
# sum_op descs are sorted according to their insert position
for
key
,
value
in
collections
.
OrderedDict
(
reversed
(
list
(
pending_sum_ops
.
items
()))).
items
():
...
...
@@ -614,12 +631,18 @@ def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
# If not reverse, we first insert 'a' at idx 1, it becomes [0, 1, 'a', 2], and then insert 'b' at idx 2, it becomes [0, 1, 'a', 'b', 2].
idx
=
key
for
i
,
op
in
enumerate
(
value
):
# update the mapping between fwd and bwd
target_idx
=
idx
-
1
if
idx
==
op_descs_len
else
idx
+
i
if
grad_op_id_to_fwd_op
is
not
None
and
grad_op_id_to_fwd_op
.
get
(
op_descs
[
target_idx
].
original_id
(),
None
)
is
not
None
:
grad_op_id_to_fwd_op
[
op
.
original_id
()]
=
grad_op_id_to_fwd_op
[
op_descs
[
target_idx
].
original_id
()]
op_descs
.
insert
(
idx
+
i
,
op
)
return
op_descs
def
_remove_no_grad_branch_
(
op_descs
,
no_grad_set
):
def
_remove_no_grad_branch_
(
op_descs
,
no_grad_set
,
grad_op_id_to_fwd_op
=
None
):
"""
Remove unnecessary grad ops
A grad op can be removed in two cases:
...
...
@@ -653,9 +676,14 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
x_in
=
_strip_grad_suffix_
(
arg
)
# the reason should be: arg can be input of another grad op
# and the op is a not-to-remove op
to_insert
.
append
(
(
_create_op_desc_
(
"fill_zeros_like"
,
{
"X"
:
[
x_in
]},
{
"Out"
:
[
arg
]},
{}),
idx
))
new_op_desc
=
_create_op_desc_
(
"fill_zeros_like"
,
{
"X"
:
[
x_in
]},
{
"Out"
:
[
arg
]},
{})
# update the mapping between fwd and bwd
if
grad_op_id_to_fwd_op
is
not
None
and
grad_op_id_to_fwd_op
.
get
(
op_desc
.
original_id
(),
None
)
is
not
None
:
grad_op_id_to_fwd_op
[
new_op_desc
.
original_id
(
)]
=
grad_op_id_to_fwd_op
[
op_desc
.
original_id
()]
to_insert
.
append
((
new_op_desc
,
idx
))
list
([
op_descs
.
insert
(
p
[
1
],
p
[
0
])
for
p
in
reversed
(
to_insert
)])
...
...
@@ -794,9 +822,13 @@ def serialize_op_decs(op_desc):
return
proto
.
__str__
()
def
_append_backward_ops_with_checkpoints_
(
block
,
ops
,
target_block
,
no_grad_dict
,
grad_to_var
,
checkpoints
):
def
_append_backward_ops_with_checkpoints_
(
block
,
ops
,
target_block
,
no_grad_dict
,
grad_to_var
,
checkpoints
,
grad_op_id_to_fwd_op
=
None
):
"""
Create grad ops with forward ops, and insert them into given block
...
...
@@ -926,12 +958,19 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
_pretty_op_desc_
(
op
.
desc
,
"with_sub_block"
))
grad_op_desc
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
op
.
desc
,
cpt
.
to_text
(
no_grad_dict
[
block
.
idx
]),
[])
# record the mapping between fwd and bwd
if
grad_op_id_to_fwd_op
is
not
None
:
for
op_desc
in
grad_op_desc
:
grad_op_id_to_fwd_op
[
op_desc
.
original_id
()]
=
op
# Set device for grad_op according to forward Op
if
op
.
desc
.
has_attr
(
device_attr_name
):
op_device
=
op
.
desc
.
attr
(
device_attr_name
)
for
op_desc
in
grad_op_desc
:
op_desc
.
_set_attr
(
device_attr_name
,
op_device
)
added_descs
=
_add_descs_to_block
(
grad_op_desc
,
local_block
)
added_descs
=
_add_descs_to_block
(
grad_op_desc
,
local_block
,
grad_op_id_to_fwd_op
)
grad_op_descs
.
extend
(
added_descs
)
grad_to_var
.
update
(
op_grad_to_var
)
...
...
@@ -945,12 +984,19 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
_pretty_op_desc_
(
op
.
desc
,
"with_sub_block"
))
grad_op_desc
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
op
.
desc
,
cpt
.
to_text
(
no_grad_dict
[
block
.
idx
]),
[])
# record the mapping between fwd and bwd
if
grad_op_id_to_fwd_op
is
not
None
:
for
op_desc
in
grad_op_desc
:
grad_op_id_to_fwd_op
[
op_desc
.
original_id
()]
=
op
# Set device for grad_op according to forward Op
if
op
.
desc
.
has_attr
(
device_attr_name
):
op_device
=
op
.
desc
.
attr
(
device_attr_name
)
for
op_desc
in
grad_op_desc
:
op_desc
.
_set_attr
(
device_attr_name
,
op_device
)
added_descs
=
_add_descs_to_block
(
grad_op_desc
,
local_block
)
added_descs
=
_add_descs_to_block
(
grad_op_desc
,
local_block
,
grad_op_id_to_fwd_op
)
grad_op_descs
.
extend
(
added_descs
)
grad_to_var
.
update
(
op_grad_to_var
)
...
...
@@ -984,8 +1030,10 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
# 3.a. add ops in current recompute_segment as forward recomputation ops
buffer_descs
=
_add_needed_descs_to_block
(
ff_ops
,
buffer_block
,
block
,
vars_in_memory
)
added_descs
=
_add_descs_to_block
(
ff_ops
,
local_block
)
vars_in_memory
,
grad_op_id_to_fwd_op
)
added_descs
=
_add_descs_to_block
(
ff_ops
,
local_block
,
grad_op_id_to_fwd_op
)
# 3.b. rename all non-checkpoint variables in recomputation ops
for
key
in
var_name_dict
:
...
...
@@ -999,6 +1047,12 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
grad_op_desc
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
op_desc
,
cpt
.
to_text
(
no_grad_dict
[
block
.
idx
]),
[])
# record the mapping between fwd and bwd
if
grad_op_id_to_fwd_op
is
not
None
:
for
g_op_desc
in
grad_op_desc
:
grad_op_id_to_fwd_op
[
g_op_desc
.
original_id
(
)]
=
grad_op_id_to_fwd_op
[
op_desc
.
original_id
()]
# Set device for grad_op according to forward Op
if
op_desc
.
has_attr
(
device_attr_name
):
op_device
=
op_desc
.
attr
(
device_attr_name
)
...
...
@@ -1011,11 +1065,14 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
grad_to_var
.
update
(
op_grad_to_var
)
# 3.d. add sum op for repetitive_outputs
grad_op_descs
=
_addup_repetitive_outputs_
(
grad_op_descs
,
block
.
idx
)
grad_op_descs
=
_addup_repetitive_outputs_
(
grad_op_descs
,
block
.
idx
,
grad_op_id_to_fwd_op
=
grad_op_id_to_fwd_op
)
# 4) remove no grad branch as it is in _remove_no_grad_branch_
grad_op_descs
=
_remove_no_grad_branch_
(
grad_op_descs
,
no_grad_dict
[
block
.
idx
])
added_descs
=
_add_descs_to_block
(
grad_op_descs
,
target_block
)
no_grad_dict
[
block
.
idx
],
grad_op_id_to_fwd_op
)
added_descs
=
_add_descs_to_block
(
grad_op_descs
,
target_block
,
grad_op_id_to_fwd_op
)
return
program_stat
,
checkpoints_name
,
vars_should_be_hold
,
recompute_segments
...
...
@@ -1090,7 +1147,8 @@ def _append_backward_ops_(block,
input_grad_names_set
=
None
,
op_path_dict
=
None
,
distop_context
=
None
,
rename_var_map
=
None
):
rename_var_map
=
None
,
grad_op_id_to_fwd_op
=
None
):
"""
Create all grad ops, and insert them into given block
...
...
@@ -1152,9 +1210,15 @@ def _append_backward_ops_(block,
pre_input_grad_names_set
=
copy
.
copy
(
input_grad_names_set
)
input_grad_names_set
=
None
sub_block_path
=
op_path_dict
[
op
.
_block_attr_id
(
"sub_block"
)]
_append_backward_ops_
(
sub_block
,
sub_block_path
,
grad_sub_block
,
no_grad_dict
,
grad_to_var
,
callbacks
,
input_grad_names_set
,
op_path_dict
)
_append_backward_ops_
(
sub_block
,
sub_block_path
,
grad_sub_block
,
no_grad_dict
,
grad_to_var
,
callbacks
,
input_grad_names_set
,
op_path_dict
,
grad_op_id_to_fwd_op
=
grad_op_id_to_fwd_op
)
input_grad_names_set
=
pre_input_grad_names_set
program
.
_rollback
()
...
...
@@ -1164,6 +1228,11 @@ def _append_backward_ops_(block,
grad_op_desc
,
op_grad_to_var
=
core
.
get_grad_op_desc
(
op
.
desc
,
cpt
.
to_text
(
no_grad_dict
[
block
.
idx
]),
grad_sub_block_list
)
# record the mapping between fwd and bwd
if
grad_op_id_to_fwd_op
is
not
None
:
for
op_desc
in
grad_op_desc
:
grad_op_id_to_fwd_op
[
op_desc
.
original_id
()]
=
op
# Build the mapping between the forward op and backward op (Only for auto parallel)
if
distop_context
is
not
None
:
update_distop_context
(
distop_context
,
op_grad_to_var
,
...
...
@@ -1251,13 +1320,17 @@ def _append_backward_ops_(block,
grad_var_to_var
=
distop_context
.
grad_var_to_var
[
program
.
_appending_grad_times
]
# sum parameter's gradients' var given multiple var gradient
grad_op_descs
=
_addup_repetitive_outputs_
(
grad_op_descs
,
block
.
idx
,
grad_var_to_var
)
grad_op_descs
=
_addup_repetitive_outputs_
(
grad_op_descs
,
block
.
idx
,
grad_var_to_var
,
grad_op_id_to_fwd_op
=
grad_op_id_to_fwd_op
)
# if all outputs of the grad op are in no_grad_set, then just remove and fill zero
# if all inputs of the grad op are in no_grad_set, just remove this op
grad_op_descs
=
_remove_no_grad_branch_
(
grad_op_descs
,
no_grad_dict
[
block
.
idx
])
no_grad_dict
[
block
.
idx
],
grad_op_id_to_fwd_op
)
# remove some backward ops
not_need_ops
=
_find_not_need_ops
(
grad_op_descs
,
ops
,
input_grad_names_set
)
...
...
@@ -1585,6 +1658,9 @@ def append_backward(loss,
p_g_list6 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights, no_grad_set=set(all_weights))
"""
grad_op_id_to_fwd_op
=
{
}
# for cuda graph usage, recording the mapping between grad op original id to fwd op
check_type
(
loss
,
'loss'
,
framework
.
Variable
,
'paddle.static.append_backward'
)
...
...
@@ -1644,7 +1720,9 @@ def append_backward(loss,
grad_to_var
=
dict
()
# pass the cuda_graph_attr to the fill_constant which generates the loss_grad
op_desc
=
_create_loss_op_desc_
(
loss
)
grad_op_id_to_fwd_op
[
op_desc
.
original_id
()]
=
loss
.
op
target_grad_block
.
desc
.
append_op
().
copy_from
(
op_desc
)
for
block_idx
in
son_parent_block_idx_dict
:
...
...
@@ -1690,7 +1768,8 @@ def append_backward(loss,
root_block
,
no_grad_dict
,
grad_to_var
,
checkpoints
)
checkpoints
,
grad_op_id_to_fwd_op
)
else
:
_append_backward_ops_
(
block
,
# the block where forward ops are in
...
...
@@ -1702,7 +1781,7 @@ def append_backward(loss,
input_grad_names_set
=
input_grad_names_set
,
op_path_dict
=
op_path_dict
,
distop_context
=
distop_context
,
)
grad_op_id_to_fwd_op
=
grad_op_id_to_fwd_op
)
grad_info_map
=
dict
()
...
...
@@ -1722,6 +1801,12 @@ def append_backward(loss,
program
.
current_block_idx
=
current_block_idx
program
.
_sync_with_cpp
()
# for cuda graph, copy the cuda graph attr from forward op to backward op
for
op
in
target_grad_block
.
ops
:
if
grad_op_id_to_fwd_op
.
get
(
op
.
desc
.
original_id
(),
None
)
is
not
None
:
fwd_op
=
grad_op_id_to_fwd_op
[
op
.
desc
.
original_id
()]
op
.
_cuda_graph_attr
=
fwd_op
.
_cuda_graph_attr
if
parameter_list
is
not
None
:
check_type
(
parameter_list
,
'parameter_list'
,
(
list
,
tuple
,
set
),
'fluid.backward.append_backward'
)
...
...
python/paddle/fluid/framework.py
浏览文件 @
b4a3dab7
...
...
@@ -81,6 +81,7 @@ global_prog_seed = 0
_current_pipeline_stage
=
None
_already_patch_eager_tensor
=
False
_already_patch_varbase
=
False
_current_cuda_graph_mode
=
None
_global_flags_
=
core
.
globals
()
# Some explanation of our execution system 2022.03
...
...
@@ -2622,6 +2623,9 @@ class Operator(object):
op_attrs
=
dict
()
del
attrs
# attr for static mode cuda graph
self
.
_cuda_graph_attr
=
_current_cuda_graph_mode
op_maker
=
core
.
op_proto_and_checker_maker
if
op_maker
.
kOpRoleAttrName
()
not
in
op_attrs
:
...
...
@@ -7017,6 +7021,37 @@ def device_guard(device=None):
switch_device
(
pre_device
)
def
_switch_cuda_graph_mode
(
cuda_graph_attr
):
global
_current_cuda_graph_mode
pre_mode
=
_current_cuda_graph_mode
_current_cuda_graph_mode
=
cuda_graph_attr
return
pre_mode
@
signature_safe_contextmanager
def
_cuda_graph_guard
(
cuda_graph_attr
=
None
):
"""
Note:
The API only supports static mode.
A context manager that specifies the cuda_graph_mode which indicating the cuda graph capture under static mode.
Args:
cuda_graph_attr(str|None): The cuda graph attr with the format of:
cuda_graph_capture_mode;memory_pool_id;cuda_graph_id
"""
assert
not
_non_static_mode
(
),
"cuda_graph_guard only works under static mode"
assert
core
.
is_compiled_with_cuda
(
),
"cuda_graph_guard context can be only used when Paddle is compiled with cuda"
pre_mode
=
_switch_cuda_graph_mode
(
cuda_graph_attr
)
try
:
yield
finally
:
_switch_cuda_graph_mode
(
pre_mode
)
def
set_flags
(
flags
):
"""
This function sets the GFlags value in Paddle.
...
...
python/paddle/fluid/tests/unittests/test_cuda_graph_partial_graph_static.py
0 → 100644
浏览文件 @
b4a3dab7
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.nn
as
nn
import
unittest
import
numpy
as
np
from
paddle.device.cuda.graphs
import
wrap_cuda_graph
,
is_cuda_graph_supported
paddle
.
enable_static
()
class
SimpleModel
(
nn
.
Layer
):
def
__init__
(
self
,
in_size
,
out_size
):
super
(
SimpleModel
,
self
).
__init__
()
self
.
linear
=
nn
.
Linear
(
in_size
,
out_size
)
self
.
dropout_1
=
paddle
.
nn
.
Dropout
(
0.1
)
self
.
relu
=
nn
.
ReLU
()
self
.
dropout_2
=
paddle
.
nn
.
Dropout
(
0.5
)
self
.
gelu
=
nn
.
GELU
()
def
forward
(
self
,
x
):
x
=
self
.
linear
(
x
)
x
=
self
.
dropout_1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
dropout_2
(
x
)
x
=
self
.
gelu
(
x
)
return
x
class
TestCudaGraphAttrAll
(
unittest
.
TestCase
):
def
test_all_program
(
self
):
if
not
is_cuda_graph_supported
():
return
main_prog
=
paddle
.
static
.
Program
()
start_prog
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main_prog
,
start_prog
):
model
=
SimpleModel
(
10
,
20
)
cuda_graph_model
=
wrap_cuda_graph
(
model
)
x
=
paddle
.
static
.
data
(
shape
=
[
3
,
10
],
dtype
=
'float32'
,
name
=
'x'
)
y
=
cuda_graph_model
(
x
)
loss
=
paddle
.
mean
(
y
)
opt
=
paddle
.
optimizer
.
SGD
()
opt
.
minimize
(
loss
)
block
=
main_prog
.
global_block
()
for
op
in
block
.
ops
:
if
op
.
_cuda_graph_attr
is
None
:
# the loss and opt are not wrapped
assert
op
.
type
in
[
'sgd'
,
'reduce_mean'
,
'fill_constant'
,
'reduce_mean_grad'
]
else
:
assert
op
.
_cuda_graph_attr
==
'thread_local;0;0'
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录