Unverified commit b4a3dab7, authored by Yuang Liu, committed by GitHub

[cuda graph] Add cuda graph attr to op desc (#43228)

Parent 2922985a
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
+import paddle
 from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace
 if is_compiled_with_cuda() and not is_compiled_with_rocm():
@@ -28,6 +29,7 @@ else:
 ALL_MODES = ["global", "thread_local", "relaxed"]
+cuda_graph_id = 0
 class CUDAGraph:
@@ -68,6 +70,24 @@ class CUDAGraph:
 def wrap_cuda_graph(function, mode="thread_local", memory_pool="default"):
     assert mode in ALL_MODES
+    if not paddle.in_dynamic_mode():
+        # static mode
+        from paddle.fluid.framework import _cuda_graph_guard
+        global cuda_graph_id
+        graph_id = str(cuda_graph_id)
+        cuda_graph_id += 1
+        if memory_pool == 'default':
+            memory_pool_id = 0
+        elif memory_pool == 'new':
+            memory_pool_id = CoreCUDAGraph.gen_new_memory_pool_id()
+        else:
+            raise ValueError(
+                "memory_pool should be one of default or new under static mode, but got",
+                memory_pool)
+        return _cuda_graph_guard(
+            mode + ';' + str(memory_pool_id) + ';' +
+            graph_id)(lambda *args, **kwargs: function(*args, **kwargs))
     from paddle.jit import to_static
     from paddle.nn import Layer
     new_function = to_static(function)
......
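With this change, `wrap_cuda_graph` short-circuits under static mode: instead of converting the callable with `to_static` and capturing a `CUDAGraph` eagerly, it wraps the call in `_cuda_graph_guard`, so every op built inside the call is stamped with a `_cuda_graph_attr` string of the form `capture_mode;memory_pool_id;graph_id`. A minimal usage sketch, mirroring the new unit test at the end of this diff (the layer, shapes, and names are illustrative, and a CUDA build is assumed):

import paddle
from paddle.device.cuda.graphs import wrap_cuda_graph

paddle.enable_static()
main_prog, start_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
    # ops created inside the wrapped call get e.g. _cuda_graph_attr = 'thread_local;0;0'
    model = wrap_cuda_graph(paddle.nn.Linear(10, 20), mode='thread_local')
    x = paddle.static.data(name='x', shape=[3, 10], dtype='float32')
    y = model(x)
for op in main_prog.global_block().ops:
    print(op.type, op._cuda_graph_attr)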
@@ -236,7 +236,11 @@ def _pretty_op_desc_(op_desc, prefix):
     return out_s
-def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
+def _add_needed_descs_to_block(descs,
+                               block,
+                               main_block,
+                               in_memory_vars,
+                               grad_op_id_to_fwd_op=None):
     if len(descs) == 0:
         return []
     result_descs = []
@@ -244,8 +248,11 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
         core.op_proto_and_checker_maker.kOpRoleAttrName()
     backward = core.op_proto_and_checker_maker.OpRole.Backward
     for desc in descs:
+        origin_desc = desc
+        origin_is_operator = False
         if isinstance(desc, framework.Operator):
             desc = desc.desc
+            origin_is_operator = True
         if isinstance(desc, tuple):
             desc = desc[0]
         is_needed = False
@@ -255,6 +262,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
             if name not in in_memory_vars:
                 is_needed = True
         if is_needed:
+            if origin_is_operator and grad_op_id_to_fwd_op is not None:
+                grad_op_id_to_fwd_op[desc.original_id()] = origin_desc
             new_op_desc = block.desc.append_op()
             new_op_desc.copy_from(desc)
             new_op_desc._set_attr(op_role_attr_name, backward)
@@ -264,7 +273,7 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
     return result_descs
-def _add_descs_to_block(descs, block):
+def _add_descs_to_block(descs, block, grad_op_id_to_fwd_op=None):
     if len(descs) == 0:
         return []
     result_descs = []
@@ -273,6 +282,9 @@ def _add_descs_to_block(descs, block):
     backward = core.op_proto_and_checker_maker.OpRole.Backward
     for desc in descs:
         if isinstance(desc, framework.Operator):
+            # for recompute, should record recompute ops
+            if grad_op_id_to_fwd_op is not None:
+                grad_op_id_to_fwd_op[desc.desc.original_id()] = desc
             desc = desc.desc
         if isinstance(desc, tuple):
             desc = desc[0]
@@ -489,7 +501,10 @@ def _accumulate_gradients_by_add_ops_(var_name,
     renamed_vars[var_name] = [var_name]
-def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
+def _addup_repetitive_outputs_(op_descs,
+                               block_idx,
+                               grad_var_to_var=None,
+                               grad_op_id_to_fwd_op=None):
     """
     In backward part, a variable may be the output of more than one op.
     And one op may yield its multiple outputs to the same variable.
@@ -500,6 +515,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
         grad_var_to_var(dict): used to build the mapping between grad var name and forward var name.
             Only for auto parallel.
     """
     _MAX_ADD_NUM_ = framework._global_flags()['FLAGS_max_inplace_grad_add']
     #pending_sum_ops = []
     pending_sum_ops = collections.OrderedDict()
@@ -604,6 +620,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
                 len(op_descs),
                 var_device[var_name])
+    op_descs_len = len(op_descs)
     # sum_op descs are sorted according to their insert position
     for key, value in collections.OrderedDict(
             reversed(list(pending_sum_ops.items()))).items():
@@ -614,12 +631,18 @@ def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
         # If not reverse, we first insert 'a' at idx 1, it becomes [0, 1, 'a', 2], and then insert 'b' at idx 2, it becomes [0, 1, 'a', 'b', 2].
         idx = key
         for i, op in enumerate(value):
+            # update the mapping between fwd and bwd
+            target_idx = idx - 1 if idx == op_descs_len else idx + i
+            if grad_op_id_to_fwd_op is not None and grad_op_id_to_fwd_op.get(
+                    op_descs[target_idx].original_id(), None) is not None:
+                grad_op_id_to_fwd_op[op.original_id()] = grad_op_id_to_fwd_op[
+                    op_descs[target_idx].original_id()]
             op_descs.insert(idx + i, op)
     return op_descs
-def _remove_no_grad_branch_(op_descs, no_grad_set):
+def _remove_no_grad_branch_(op_descs, no_grad_set, grad_op_id_to_fwd_op=None):
     """
     Remove unnecessary grad ops
     A grad op can be removed in two cases:
@@ -653,9 +676,14 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
                 x_in = _strip_grad_suffix_(arg)
                 # the reason should be: arg can be input of another grad op
                 # and the op is a not-to-remove op
-                to_insert.append(
-                    (_create_op_desc_("fill_zeros_like", {"X": [x_in]},
-                                      {"Out": [arg]}, {}), idx))
+                new_op_desc = _create_op_desc_("fill_zeros_like", {"X": [x_in]},
+                                               {"Out": [arg]}, {})
+                # update the mapping between fwd and bwd
+                if grad_op_id_to_fwd_op is not None and grad_op_id_to_fwd_op.get(
+                        op_desc.original_id(), None) is not None:
+                    grad_op_id_to_fwd_op[new_op_desc.original_id(
+                    )] = grad_op_id_to_fwd_op[op_desc.original_id()]
+                to_insert.append((new_op_desc, idx))
     list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)])
@@ -794,9 +822,13 @@ def serialize_op_decs(op_desc):
     return proto.__str__()
-def _append_backward_ops_with_checkpoints_(block, ops, target_block,
-                                           no_grad_dict, grad_to_var,
-                                           checkpoints):
+def _append_backward_ops_with_checkpoints_(block,
+                                           ops,
+                                           target_block,
+                                           no_grad_dict,
+                                           grad_to_var,
+                                           checkpoints,
+                                           grad_op_id_to_fwd_op=None):
     """
     Create grad ops with forward ops, and insert them into given block
@@ -926,12 +958,19 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
                 _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+            # record the mapping between fwd and bwd
+            if grad_op_id_to_fwd_op is not None:
+                for op_desc in grad_op_desc:
+                    grad_op_id_to_fwd_op[op_desc.original_id()] = op
             # Set device for grad_op according to forward Op
             if op.desc.has_attr(device_attr_name):
                 op_device = op.desc.attr(device_attr_name)
                 for op_desc in grad_op_desc:
                     op_desc._set_attr(device_attr_name, op_device)
-            added_descs = _add_descs_to_block(grad_op_desc, local_block)
+            added_descs = _add_descs_to_block(grad_op_desc, local_block,
+                                              grad_op_id_to_fwd_op)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
@@ -945,12 +984,19 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
                 _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+            # record the mapping between fwd and bwd
+            if grad_op_id_to_fwd_op is not None:
+                for op_desc in grad_op_desc:
+                    grad_op_id_to_fwd_op[op_desc.original_id()] = op
             # Set device for grad_op according to forward Op
             if op.desc.has_attr(device_attr_name):
                 op_device = op.desc.attr(device_attr_name)
                 for op_desc in grad_op_desc:
                     op_desc._set_attr(device_attr_name, op_device)
-            added_descs = _add_descs_to_block(grad_op_desc, local_block)
+            added_descs = _add_descs_to_block(grad_op_desc, local_block,
+                                              grad_op_id_to_fwd_op)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
@@ -984,8 +1030,10 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
         # 3.a. add ops in current recompute_segment as forward recomputation ops
-        buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
-                                                  vars_in_memory)
-        added_descs = _add_descs_to_block(ff_ops, local_block)
+        buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
+                                                  vars_in_memory,
+                                                  grad_op_id_to_fwd_op)
+        added_descs = _add_descs_to_block(ff_ops, local_block,
+                                          grad_op_id_to_fwd_op)
         # 3.b. rename all non-checkpoint variables in recomputation ops
         for key in var_name_dict:
@@ -999,6 +1047,12 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op_desc, cpt.to_text(no_grad_dict[block.idx]), [])
+            # record the mapping between fwd and bwd
+            if grad_op_id_to_fwd_op is not None:
+                for g_op_desc in grad_op_desc:
+                    grad_op_id_to_fwd_op[g_op_desc.original_id(
+                    )] = grad_op_id_to_fwd_op[op_desc.original_id()]
             # Set device for grad_op according to forward Op
             if op_desc.has_attr(device_attr_name):
                 op_device = op_desc.attr(device_attr_name)
@@ -1011,11 +1065,14 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
             grad_to_var.update(op_grad_to_var)
     # 3.d. add sum op for repetitive_outputs
-    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx)
+    grad_op_descs = _addup_repetitive_outputs_(
+        grad_op_descs, block.idx, grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
     # 4) remove no grad branch as it is in _remove_no_grad_branch_
-    grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
-                                            no_grad_dict[block.idx])
-    added_descs = _add_descs_to_block(grad_op_descs, target_block)
+    grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
+                                            no_grad_dict[block.idx],
+                                            grad_op_id_to_fwd_op)
+    added_descs = _add_descs_to_block(grad_op_descs, target_block,
+                                      grad_op_id_to_fwd_op)
     return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments
@@ -1090,7 +1147,8 @@ def _append_backward_ops_(block,
                           input_grad_names_set=None,
                           op_path_dict=None,
                           distop_context=None,
-                          rename_var_map=None):
+                          rename_var_map=None,
+                          grad_op_id_to_fwd_op=None):
     """
     Create all grad ops, and insert them into given block
@@ -1152,9 +1210,15 @@ def _append_backward_ops_(block,
             pre_input_grad_names_set = copy.copy(input_grad_names_set)
             input_grad_names_set = None
             sub_block_path = op_path_dict[op._block_attr_id("sub_block")]
-            _append_backward_ops_(sub_block, sub_block_path, grad_sub_block,
-                                  no_grad_dict, grad_to_var, callbacks,
-                                  input_grad_names_set, op_path_dict)
+            _append_backward_ops_(sub_block,
+                                  sub_block_path,
+                                  grad_sub_block,
+                                  no_grad_dict,
+                                  grad_to_var,
+                                  callbacks,
+                                  input_grad_names_set,
+                                  op_path_dict,
+                                  grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
             input_grad_names_set = pre_input_grad_names_set
             program._rollback()
@@ -1164,6 +1228,11 @@ def _append_backward_ops_(block,
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
+        # record the mapping between fwd and bwd
+        if grad_op_id_to_fwd_op is not None:
+            for op_desc in grad_op_desc:
+                grad_op_id_to_fwd_op[op_desc.original_id()] = op
         # Build the mapping between the forward op and backward op (Only for auto parallel)
         if distop_context is not None:
             update_distop_context(distop_context, op_grad_to_var,
@@ -1251,13 +1320,17 @@ def _append_backward_ops_(block,
         grad_var_to_var = distop_context.grad_var_to_var[
             program._appending_grad_times]
     # sum parameter's gradients' var given multiple var gradient
-    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx,
-                                               grad_var_to_var)
+    grad_op_descs = _addup_repetitive_outputs_(
+        grad_op_descs,
+        block.idx,
+        grad_var_to_var,
+        grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
     # if all outputs of the grad op are in no_grad_set, then just remove and fill zero
     # if all inputs of the grad op are in no_grad_set, just remove this op
-    grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
-                                            no_grad_dict[block.idx])
+    grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
+                                            no_grad_dict[block.idx],
+                                            grad_op_id_to_fwd_op)
     # remove some backward ops
     not_need_ops = _find_not_need_ops(grad_op_descs, ops, input_grad_names_set)
@@ -1585,6 +1658,9 @@ def append_backward(loss,
             p_g_list6 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights, no_grad_set=set(all_weights))
     """
+    grad_op_id_to_fwd_op = {
+    }  # for cuda graph usage, recording the mapping between grad op original id to fwd op
     check_type(loss, 'loss', framework.Variable,
                'paddle.static.append_backward')
@@ -1644,7 +1720,9 @@ def append_backward(loss,
     grad_to_var = dict()
+    # pass the cuda_graph_attr to the fill_constant which generates the loss_grad
     op_desc = _create_loss_op_desc_(loss)
+    grad_op_id_to_fwd_op[op_desc.original_id()] = loss.op
     target_grad_block.desc.append_op().copy_from(op_desc)
     for block_idx in son_parent_block_idx_dict:
@@ -1690,7 +1768,8 @@ def append_backward(loss,
                 root_block,
                 no_grad_dict,
                 grad_to_var,
-                checkpoints)
+                checkpoints,
+                grad_op_id_to_fwd_op)
     else:
         _append_backward_ops_(
             block,  # the block where forward ops are in
@@ -1702,7 +1781,7 @@ def append_backward(loss,
             input_grad_names_set=input_grad_names_set,
             op_path_dict=op_path_dict,
             distop_context=distop_context,
-        )
+            grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
     grad_info_map = dict()
@@ -1722,6 +1801,12 @@ def append_backward(loss,
         program.current_block_idx = current_block_idx
         program._sync_with_cpp()
+    # for cuda graph, copy the cuda graph attr from forward op to backward op
+    for op in target_grad_block.ops:
+        if grad_op_id_to_fwd_op.get(op.desc.original_id(), None) is not None:
+            fwd_op = grad_op_id_to_fwd_op[op.desc.original_id()]
+            op._cuda_graph_attr = fwd_op._cuda_graph_attr
     if parameter_list is not None:
         check_type(parameter_list, 'parameter_list', (list, tuple, set),
                    'fluid.backward.append_backward')
......
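Net effect of the `grad_op_id_to_fwd_op` plumbing above: every grad op desc created during `append_backward` is keyed by its `original_id()` to the forward `Operator` it derives from, and the final pass copies that forward op's `_cuda_graph_attr` onto the backward op. A toy, self-contained sketch of the copy logic (plain Python objects stand in for `Operator`/`OpDesc`; the names and ids are hypothetical):

class FakeOp:
    """Stand-in for a framework.Operator with an original id and the new attr."""

    def __init__(self, original_id, cuda_graph_attr=None):
        self.original_id = original_id
        self._cuda_graph_attr = cuda_graph_attr

fwd = FakeOp(1, 'thread_local;0;0')   # forward op built under _cuda_graph_guard
bwd = FakeOp(2)                       # its grad op; attr still None

# recorded while emitting grad op descs: grad op original id -> fwd op
grad_op_id_to_fwd_op = {bwd.original_id: fwd}

# the final pass in append_backward copies the attr from fwd to bwd
for op in (bwd,):
    src = grad_op_id_to_fwd_op.get(op.original_id)
    if src is not None:
        op._cuda_graph_attr = src._cuda_graph_attr

assert bwd._cuda_graph_attr == 'thread_local;0;0'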
@@ -81,6 +81,7 @@ global_prog_seed = 0
 _current_pipeline_stage = None
 _already_patch_eager_tensor = False
 _already_patch_varbase = False
+_current_cuda_graph_mode = None
 _global_flags_ = core.globals()
 # Some explanation of our execution system 2022.03
@@ -2622,6 +2623,9 @@ class Operator(object):
             op_attrs = dict()
         del attrs
+        # attr for static mode cuda graph
+        self._cuda_graph_attr = _current_cuda_graph_mode
         op_maker = core.op_proto_and_checker_maker
         if op_maker.kOpRoleAttrName() not in op_attrs:
@@ -7017,6 +7021,37 @@ def device_guard(device=None):
         switch_device(pre_device)
+def _switch_cuda_graph_mode(cuda_graph_attr):
+    global _current_cuda_graph_mode
+    pre_mode = _current_cuda_graph_mode
+    _current_cuda_graph_mode = cuda_graph_attr
+    return pre_mode
+
+
+@signature_safe_contextmanager
+def _cuda_graph_guard(cuda_graph_attr=None):
+    """
+    Note:
+        The API only supports static mode.
+    A context manager that specifies the cuda_graph_mode, indicating cuda graph capture under static mode.
+    Args:
+        cuda_graph_attr(str|None): The cuda graph attr with the format of:
+                                   cuda_graph_capture_mode;memory_pool_id;cuda_graph_id
+    """
+    assert not _non_static_mode(
+    ), "cuda_graph_guard only works under static mode"
+    assert core.is_compiled_with_cuda(
+    ), "cuda_graph_guard context can be only used when Paddle is compiled with cuda"
+    pre_mode = _switch_cuda_graph_mode(cuda_graph_attr)
+    try:
+        yield
+    finally:
+        _switch_cuda_graph_mode(pre_mode)
 def set_flags(flags):
     """
     This function sets the GFlags value in Paddle.
......
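`_cuda_graph_guard` is the primitive that static-mode `wrap_cuda_graph` builds on: it swaps the module-global `_current_cuda_graph_mode`, which `Operator.__init__` now snapshots into `_cuda_graph_attr` for every op constructed while the guard is active. A hedged usage sketch (static mode and a CUDA build are assumed, per the asserts above; the ops shown are illustrative):

import paddle
from paddle.fluid.framework import _cuda_graph_guard

paddle.enable_static()
main_prog, start_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
    x = paddle.static.data(name='x', shape=[3, 10], dtype='float32')
    with _cuda_graph_guard('thread_local;0;0'):
        y = paddle.nn.functional.relu(x)  # op gets _cuda_graph_attr = 'thread_local;0;0'
    z = paddle.mean(y)                    # op created outside the guard: attr stays None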
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import unittest
import numpy as np
from paddle.device.cuda.graphs import wrap_cuda_graph, is_cuda_graph_supported
paddle.enable_static()


class SimpleModel(nn.Layer):

    def __init__(self, in_size, out_size):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(in_size, out_size)
        self.dropout_1 = paddle.nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.dropout_2 = paddle.nn.Dropout(0.5)
        self.gelu = nn.GELU()

    def forward(self, x):
        x = self.linear(x)
        x = self.dropout_1(x)
        x = self.relu(x)
        x = self.dropout_2(x)
        x = self.gelu(x)
        return x


class TestCudaGraphAttrAll(unittest.TestCase):

    def test_all_program(self):
        if not is_cuda_graph_supported():
            return
        main_prog = paddle.static.Program()
        start_prog = paddle.static.Program()
        with paddle.static.program_guard(main_prog, start_prog):
            model = SimpleModel(10, 20)
            cuda_graph_model = wrap_cuda_graph(model)
            x = paddle.static.data(shape=[3, 10], dtype='float32', name='x')
            y = cuda_graph_model(x)
            loss = paddle.mean(y)
            opt = paddle.optimizer.SGD()
            opt.minimize(loss)
        block = main_prog.global_block()
        for op in block.ops:
            if op._cuda_graph_attr is None:
                # the loss and opt are not wrapped
                assert op.type in [
                    'sgd', 'reduce_mean', 'fill_constant',
                    'reduce_mean_grad'
                ]
            else:
                assert op._cuda_graph_attr == 'thread_local;0;0'


if __name__ == "__main__":
    unittest.main()