Unverified commit b4a3dab7, authored by Yuang Liu, committed by GitHub

[cuda graph] Add cuda graph attr to op desc (#43228)

Parent 2922985a
......@@ -13,6 +13,7 @@
# limitations under the License.
import os
import paddle
from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace
if is_compiled_with_cuda() and not is_compiled_with_rocm():
......@@ -28,6 +29,7 @@ else:
ALL_MODES = ["global", "thread_local", "relaxed"]
cuda_graph_id = 0
class CUDAGraph:
......@@ -68,6 +70,24 @@ class CUDAGraph:
def wrap_cuda_graph(function, mode="thread_local", memory_pool="default"):
assert mode in ALL_MODES
if not paddle.in_dynamic_mode():
# static mode
from paddle.fluid.framework import _cuda_graph_guard
global cuda_graph_id
graph_id = str(cuda_graph_id)
cuda_graph_id += 1
if memory_pool == 'default':
memory_pool_id = 0
elif memory_pool == 'new':
memory_pool_id = CoreCUDAGraph.gen_new_memory_pool_id()
else:
raise ValueError(
"memory_pool should be one of default or new under static mode, but got",
memory_pool)
return _cuda_graph_guard(
mode + ';' + str(memory_pool_id) + ';' +
graph_id)(lambda *args, **kwargs: function(*args, **kwargs))
from paddle.jit import to_static
from paddle.nn import Layer
new_function = to_static(function)
......
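A minimal sketch of how the static-mode branch above is exercised (assuming a CUDA build; the attr string is just mode + ';' + memory_pool_id + ';' + graph_id, so the defaults give 'thread_local;0;0'):

import paddle
from paddle.device.cuda.graphs import wrap_cuda_graph

paddle.enable_static()

def fn(x):
    return paddle.nn.functional.relu(x)

# Under static mode wrap_cuda_graph captures nothing by itself; it only runs
# the call inside _cuda_graph_guard so that every op built during the call
# receives a _cuda_graph_attr such as 'thread_local;0;0'.
wrapped = wrap_cuda_graph(fn, mode='thread_local', memory_pool='default')

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[3, 10], dtype='float32')
    y = wrapped(x)

for op in main_prog.global_block().ops:
    print(op.type, op._cuda_graph_attr)   # e.g. relu thread_local;0;0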
......@@ -236,7 +236,11 @@ def _pretty_op_desc_(op_desc, prefix):
return out_s
def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
def _add_needed_descs_to_block(descs,
block,
main_block,
in_memory_vars,
grad_op_id_to_fwd_op=None):
if len(descs) == 0:
return []
result_descs = []
......@@ -244,8 +248,11 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
for desc in descs:
origin_desc = desc
origin_is_operator = False
if isinstance(desc, framework.Operator):
desc = desc.desc
origin_is_operator = True
if isinstance(desc, tuple):
desc = desc[0]
is_needed = False
......@@ -255,6 +262,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
if name not in in_memory_vars:
is_needed = True
if is_needed:
if origin_is_operator and grad_op_id_to_fwd_op is not None:
grad_op_id_to_fwd_op[desc.original_id()] = origin_desc
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward)
......@@ -264,7 +273,7 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
return result_descs
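A toy illustration (hypothetical variable names) of the is_needed check that drives the branch above: a desc is copied into the block only if it produces at least one output that is not already held in memory, and only then is the original Operator recorded in grad_op_id_to_fwd_op.

in_memory_vars = {'x', 'linear_0.w_0'}                  # vars already resident
desc_outputs = ['dropout_0.tmp_0', 'dropout_0.mask']    # outputs of one candidate desc

is_needed = any(name not in in_memory_vars for name in desc_outputs)
# is_needed is True, so the desc would be appended to block.desc and, when the
# original object was a framework.Operator, recorded via
# grad_op_id_to_fwd_op[desc.original_id()] = origin_desc.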
def _add_descs_to_block(descs, block):
def _add_descs_to_block(descs, block, grad_op_id_to_fwd_op=None):
if len(descs) == 0:
return []
result_descs = []
......@@ -273,6 +282,9 @@ def _add_descs_to_block(descs, block):
backward = core.op_proto_and_checker_maker.OpRole.Backward
for desc in descs:
if isinstance(desc, framework.Operator):
# for recompute, should record recompute ops
if grad_op_id_to_fwd_op is not None:
grad_op_id_to_fwd_op[desc.desc.original_id()] = desc
desc = desc.desc
if isinstance(desc, tuple):
desc = desc[0]
......@@ -489,7 +501,10 @@ def _accumulate_gradients_by_add_ops_(var_name,
renamed_vars[var_name] = [var_name]
def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
def _addup_repetitive_outputs_(op_descs,
block_idx,
grad_var_to_var=None,
grad_op_id_to_fwd_op=None):
"""
In the backward part, a variable may be the output of more than one op.
And one op may yield several of its outputs to the same variable.
......@@ -500,6 +515,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
grad_var_to_var(dict): used to build the mapping between grad var name and forward var name.
Only for auto parallel.
"""
_MAX_ADD_NUM_ = framework._global_flags()['FLAGS_max_inplace_grad_add']
#pending_sum_ops = []
pending_sum_ops = collections.OrderedDict()
......@@ -604,6 +620,7 @@ def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
len(op_descs),
var_device[var_name])
op_descs_len = len(op_descs)
# sum_op descs are sorted according to their insert position
for key, value in collections.OrderedDict(
reversed(list(pending_sum_ops.items()))).items():
......@@ -614,12 +631,18 @@ def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
# If not reversed, we would first insert 'a' at idx 1, giving [0, 'a', 1, 2], and then insert 'b' at idx 2, giving [0, 'a', 'b', 1, 2], which shifts the later insert positions; inserting in reverse order keeps every recorded index valid.
idx = key
for i, op in enumerate(value):
# update the mapping between fwd and bwd
target_idx = idx - 1 if idx == op_descs_len else idx + i
if grad_op_id_to_fwd_op is not None and grad_op_id_to_fwd_op.get(
op_descs[target_idx].original_id(), None) is not None:
grad_op_id_to_fwd_op[op.original_id()] = grad_op_id_to_fwd_op[
op_descs[target_idx].original_id()]
op_descs.insert(idx + i, op)
return op_descs
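A standalone sketch (placeholder strings rather than real op descs) of why the pending sum ops are inserted in reverse order of their recorded positions: inserting from the largest index first leaves the smaller indices untouched.

op_descs = ['op0', 'op1', 'op2']
pending_sum_ops = {1: ['sum_a'], 2: ['sum_b']}   # insert position -> sum ops

for idx, ops in sorted(pending_sum_ops.items(), reverse=True):
    for i, op in enumerate(ops):
        op_descs.insert(idx + i, op)

print(op_descs)   # ['op0', 'sum_a', 'op1', 'sum_b', 'op2']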
def _remove_no_grad_branch_(op_descs, no_grad_set):
def _remove_no_grad_branch_(op_descs, no_grad_set, grad_op_id_to_fwd_op=None):
"""
Remove unnecessary grad ops
A grad op can be removed in two cases:
......@@ -653,9 +676,14 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
x_in = _strip_grad_suffix_(arg)
# the reason should be: arg can be input of another grad op
# and the op is a not-to-remove op
to_insert.append(
(_create_op_desc_("fill_zeros_like", {"X": [x_in]},
{"Out": [arg]}, {}), idx))
new_op_desc = _create_op_desc_("fill_zeros_like", {"X": [x_in]},
{"Out": [arg]}, {})
# update the mapping between fwd and bwd
if grad_op_id_to_fwd_op is not None and grad_op_id_to_fwd_op.get(
op_desc.original_id(), None) is not None:
grad_op_id_to_fwd_op[new_op_desc.original_id(
)] = grad_op_id_to_fwd_op[op_desc.original_id()]
to_insert.append((new_op_desc, idx))
list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)])
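A simplified model (made-up names, plain dicts instead of op descs) of the pruning rules this function applies, as summarized later in _append_backward_ops_: a grad op is dropped when all of its outputs, or all of its inputs, need no gradient, and a fill_zeros_like op is inserted for any pruned grad var that a surviving grad op still consumes.

no_grad_set = {'b@GRAD'}
grad_ops = [
    {'type': 'relu_grad',   'inputs': ['y@GRAD'],      'outputs': ['b@GRAD']},
    {'type': 'matmul_grad', 'inputs': ['b@GRAD', 'x'], 'outputs': ['w@GRAD']},
]

def removable(op):
    return all(o in no_grad_set for o in op['outputs']) or \
           all(i in no_grad_set for i in op['inputs'])

kept = [op for op in grad_ops if not removable(op)]
# relu_grad is removed (its only output b@GRAD needs no gradient); since the
# kept matmul_grad still reads b@GRAD, the real code inserts a fill_zeros_like
# op producing b@GRAD in front of it and, with this commit, records that new
# desc in grad_op_id_to_fwd_op as well.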
......@@ -794,9 +822,13 @@ def serialize_op_decs(op_desc):
return proto.__str__()
def _append_backward_ops_with_checkpoints_(block, ops, target_block,
no_grad_dict, grad_to_var,
checkpoints):
def _append_backward_ops_with_checkpoints_(block,
ops,
target_block,
no_grad_dict,
grad_to_var,
checkpoints,
grad_op_id_to_fwd_op=None):
"""
Create grad ops with forward ops, and insert them into given block
......@@ -926,12 +958,19 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
_pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# record the mapping between fwd and bwd
if grad_op_id_to_fwd_op is not None:
for op_desc in grad_op_desc:
grad_op_id_to_fwd_op[op_desc.original_id()] = op
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
added_descs = _add_descs_to_block(grad_op_desc, local_block)
added_descs = _add_descs_to_block(grad_op_desc, local_block,
grad_op_id_to_fwd_op)
grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var)
......@@ -945,12 +984,19 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
_pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# record the mapping between fwd and bwd
if grad_op_id_to_fwd_op is not None:
for op_desc in grad_op_desc:
grad_op_id_to_fwd_op[op_desc.original_id()] = op
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
added_descs = _add_descs_to_block(grad_op_desc, local_block)
added_descs = _add_descs_to_block(grad_op_desc, local_block,
grad_op_id_to_fwd_op)
grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var)
......@@ -984,8 +1030,10 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
# 3.a. add ops in current recompute_segment as forward recomputation ops
buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
vars_in_memory)
added_descs = _add_descs_to_block(ff_ops, local_block)
vars_in_memory,
grad_op_id_to_fwd_op)
added_descs = _add_descs_to_block(ff_ops, local_block,
grad_op_id_to_fwd_op)
# 3.b. rename all non-checkpoint variables in recomputation ops
for key in var_name_dict:
......@@ -999,6 +1047,12 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op_desc, cpt.to_text(no_grad_dict[block.idx]), [])
# record the mapping between fwd and bwd
if grad_op_id_to_fwd_op is not None:
for g_op_desc in grad_op_desc:
grad_op_id_to_fwd_op[g_op_desc.original_id(
)] = grad_op_id_to_fwd_op[op_desc.original_id()]
# Set device for grad_op according to forward Op
if op_desc.has_attr(device_attr_name):
op_device = op_desc.attr(device_attr_name)
......@@ -1011,11 +1065,14 @@ def _append_backward_ops_with_checkpoints_(block, ops, target_block,
grad_to_var.update(op_grad_to_var)
# 3.d. add sum op for repetitive_outputs
grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx)
grad_op_descs = _addup_repetitive_outputs_(
grad_op_descs, block.idx, grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
# 4) remove no grad branch as it is in _remove_no_grad_branch_
grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
no_grad_dict[block.idx])
added_descs = _add_descs_to_block(grad_op_descs, target_block)
no_grad_dict[block.idx],
grad_op_id_to_fwd_op)
added_descs = _add_descs_to_block(grad_op_descs, target_block,
grad_op_id_to_fwd_op)
return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments
......@@ -1090,7 +1147,8 @@ def _append_backward_ops_(block,
input_grad_names_set=None,
op_path_dict=None,
distop_context=None,
rename_var_map=None):
rename_var_map=None,
grad_op_id_to_fwd_op=None):
"""
Create all grad ops, and insert them into given block
......@@ -1152,9 +1210,15 @@ def _append_backward_ops_(block,
pre_input_grad_names_set = copy.copy(input_grad_names_set)
input_grad_names_set = None
sub_block_path = op_path_dict[op._block_attr_id("sub_block")]
_append_backward_ops_(sub_block, sub_block_path, grad_sub_block,
no_grad_dict, grad_to_var, callbacks,
input_grad_names_set, op_path_dict)
_append_backward_ops_(sub_block,
sub_block_path,
grad_sub_block,
no_grad_dict,
grad_to_var,
callbacks,
input_grad_names_set,
op_path_dict,
grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
input_grad_names_set = pre_input_grad_names_set
program._rollback()
......@@ -1164,6 +1228,11 @@ def _append_backward_ops_(block,
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
# record the mapping between fwd and bwd
if grad_op_id_to_fwd_op is not None:
for op_desc in grad_op_desc:
grad_op_id_to_fwd_op[op_desc.original_id()] = op
# Build the mapping between the forward op and backward op (Only for auto parallel)
if distop_context is not None:
update_distop_context(distop_context, op_grad_to_var,
......@@ -1251,13 +1320,17 @@ def _append_backward_ops_(block,
grad_var_to_var = distop_context.grad_var_to_var[
program._appending_grad_times]
# sum parameter's gradients' var given multiple var gradient
grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx,
grad_var_to_var)
grad_op_descs = _addup_repetitive_outputs_(
grad_op_descs,
block.idx,
grad_var_to_var,
grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
# if all outputs of the grad op are in no_grad_set, then just remove and fill zero
# if all inputs of the grad op are in no_grad_set, just remove this op
grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
no_grad_dict[block.idx])
no_grad_dict[block.idx],
grad_op_id_to_fwd_op)
# remove some backward ops
not_need_ops = _find_not_need_ops(grad_op_descs, ops, input_grad_names_set)
......@@ -1585,6 +1658,9 @@ def append_backward(loss,
p_g_list6 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights, no_grad_set=set(all_weights))
"""
grad_op_id_to_fwd_op = {
} # for cuda graph usage, recording the mapping from grad op original id to fwd op
check_type(loss, 'loss', framework.Variable,
'paddle.static.append_backward')
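For orientation, the mapping created above has roughly this shape (ids and names are made up): keys are OpDesc.original_id() values of generated backward descs, values are the forward framework.Operator objects whose _cuda_graph_attr is copied onto the corresponding backward ops at the end of append_backward.

grad_op_id_to_fwd_op = {
    140234881: 'fwd matmul_v2 op, _cuda_graph_attr=thread_local;0;0',
    140234957: 'fwd relu op, _cuda_graph_attr=thread_local;0;0',
}
fwd = grad_op_id_to_fwd_op.get(140234881)   # lookup done per backward op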
......@@ -1644,7 +1720,9 @@ def append_backward(loss,
grad_to_var = dict()
# pass the cuda_graph_attr to the fill_constant which generates the loss_grad
op_desc = _create_loss_op_desc_(loss)
grad_op_id_to_fwd_op[op_desc.original_id()] = loss.op
target_grad_block.desc.append_op().copy_from(op_desc)
for block_idx in son_parent_block_idx_dict:
......@@ -1690,7 +1768,8 @@ def append_backward(loss,
root_block,
no_grad_dict,
grad_to_var,
checkpoints)
checkpoints,
grad_op_id_to_fwd_op)
else:
_append_backward_ops_(
block, # the block where forward ops are in
......@@ -1702,7 +1781,7 @@ def append_backward(loss,
input_grad_names_set=input_grad_names_set,
op_path_dict=op_path_dict,
distop_context=distop_context,
)
grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
grad_info_map = dict()
......@@ -1722,6 +1801,12 @@ def append_backward(loss,
program.current_block_idx = current_block_idx
program._sync_with_cpp()
# for cuda graph, copy the cuda graph attr from forward op to backward op
for op in target_grad_block.ops:
if grad_op_id_to_fwd_op.get(op.desc.original_id(), None) is not None:
fwd_op = grad_op_id_to_fwd_op[op.desc.original_id()]
op._cuda_graph_attr = fwd_op._cuda_graph_attr
if parameter_list is not None:
check_type(parameter_list, 'parameter_list', (list, tuple, set),
'fluid.backward.append_backward')
......
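A minimal stand-in (plain Python objects, hypothetical) for the final copy loop above, showing how a backward op inherits _cuda_graph_attr from the forward op it was generated from:

class FakeOp:
    def __init__(self, original_id, cuda_graph_attr=None):
        self.original_id = original_id
        self._cuda_graph_attr = cuda_graph_attr

fwd_relu = FakeOp(1, 'thread_local;0;0')
bwd_relu_grad = FakeOp(2)                    # grad op, attr not set yet
grad_op_id_to_fwd_op = {2: fwd_relu}

for op in [bwd_relu_grad]:                   # stands in for target_grad_block.ops
    fwd = grad_op_id_to_fwd_op.get(op.original_id)
    if fwd is not None:
        op._cuda_graph_attr = fwd._cuda_graph_attr

assert bwd_relu_grad._cuda_graph_attr == 'thread_local;0;0'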
......@@ -81,6 +81,7 @@ global_prog_seed = 0
_current_pipeline_stage = None
_already_patch_eager_tensor = False
_already_patch_varbase = False
_current_cuda_graph_mode = None
_global_flags_ = core.globals()
# Some explanation of our execution system 2022.03
......@@ -2622,6 +2623,9 @@ class Operator(object):
op_attrs = dict()
del attrs
# attr for static mode cuda graph
self._cuda_graph_attr = _current_cuda_graph_mode
op_maker = core.op_proto_and_checker_maker
if op_maker.kOpRoleAttrName() not in op_attrs:
......@@ -7017,6 +7021,37 @@ def device_guard(device=None):
switch_device(pre_device)
def _switch_cuda_graph_mode(cuda_graph_attr):
global _current_cuda_graph_mode
pre_mode = _current_cuda_graph_mode
_current_cuda_graph_mode = cuda_graph_attr
return pre_mode
@signature_safe_contextmanager
def _cuda_graph_guard(cuda_graph_attr=None):
"""
Note:
The API only supports static mode.
A context manager that specifies the cuda_graph_mode, indicating the cuda graph capture under static mode.
Args:
cuda_graph_attr(str|None): The cuda graph attr with the format of:
cuda_graph_capture_mode;memory_pool_id;cuda_graph_id
"""
assert not _non_static_mode(
), "cuda_graph_guard only works under static mode"
assert core.is_compiled_with_cuda(
), "cuda_graph_guard context can be only used when Paddle is compiled with cuda"
pre_mode = _switch_cuda_graph_mode(cuda_graph_attr)
try:
yield
finally:
_switch_cuda_graph_mode(pre_mode)
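A brief usage sketch of the guard defined above (static mode and a CUDA build assumed); the attr string follows the documented capture_mode;memory_pool_id;cuda_graph_id format:

import paddle
from paddle.fluid.framework import _cuda_graph_guard

paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name='x', shape=[4, 8], dtype='float32')
    # ops built inside the guard pick the attr up in Operator.__init__
    with _cuda_graph_guard('relaxed;0;0'):
        y = paddle.nn.functional.relu(x)
    z = paddle.mean(y)                       # outside the guard -> attr stays None

print([(op.type, op._cuda_graph_attr)
       for op in main_prog.global_block().ops])
# e.g. [('relu', 'relaxed;0;0'), ('reduce_mean', None)]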
def set_flags(flags):
"""
This function sets the GFlags value in Paddle.
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import unittest
import numpy as np
from paddle.device.cuda.graphs import wrap_cuda_graph, is_cuda_graph_supported
paddle.enable_static()
class SimpleModel(nn.Layer):
def __init__(self, in_size, out_size):
super(SimpleModel, self).__init__()
self.linear = nn.Linear(in_size, out_size)
self.dropout_1 = paddle.nn.Dropout(0.1)
self.relu = nn.ReLU()
self.dropout_2 = paddle.nn.Dropout(0.5)
self.gelu = nn.GELU()
def forward(self, x):
x = self.linear(x)
x = self.dropout_1(x)
x = self.relu(x)
x = self.dropout_2(x)
x = self.gelu(x)
return x
class TestCudaGraphAttrAll(unittest.TestCase):
def test_all_program(self):
if not is_cuda_graph_supported():
return
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, start_prog):
model = SimpleModel(10, 20)
cuda_graph_model = wrap_cuda_graph(model)
x = paddle.static.data(shape=[3, 10], dtype='float32', name='x')
y = cuda_graph_model(x)
loss = paddle.mean(y)
opt = paddle.optimizer.SGD()
opt.minimize(loss)
block = main_prog.global_block()
for op in block.ops:
if op._cuda_graph_attr is None:
# the loss and opt are not wrapped
assert op.type in [
'sgd', 'reduce_mean', 'fill_constant',
'reduce_mean_grad'
]
else:
assert op._cuda_graph_attr == 'thread_local;0;0'
if __name__ == "__main__":
unittest.main()