From b4a3dab727e5d5c50b040326ab9e52ba82b957f7 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Tue, 7 Jun 2022 16:20:37 +0800
Subject: [PATCH] [cuda graph] Add cuda graph attr to op desc (#43228)

---
 python/paddle/device/cuda/graphs.py          |  20 +++
 python/paddle/fluid/backward.py              | 137 ++++++++++++++----
 python/paddle/fluid/framework.py             |  35 +++++
 .../test_cuda_graph_partial_graph_static.py  |  71 +++++++++
 4 files changed, 237 insertions(+), 26 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_cuda_graph_partial_graph_static.py

diff --git a/python/paddle/device/cuda/graphs.py b/python/paddle/device/cuda/graphs.py
index c6554d78fb8..dca32fb6bb8 100644
--- a/python/paddle/device/cuda/graphs.py
+++ b/python/paddle/device/cuda/graphs.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
+import paddle
 from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace
 
 if is_compiled_with_cuda() and not is_compiled_with_rocm():
@@ -28,6 +29,7 @@ else:
 
 ALL_MODES = ["global", "thread_local", "relaxed"]
+cuda_graph_id = 0
 
 
 class CUDAGraph:
 
@@ -68,6 +70,24 @@ class CUDAGraph:
 
 def wrap_cuda_graph(function, mode="thread_local", memory_pool="default"):
     assert mode in ALL_MODES
+    if not paddle.in_dynamic_mode():
+        # static mode
+        from paddle.fluid.framework import _cuda_graph_guard
+        global cuda_graph_id
+        graph_id = str(cuda_graph_id)
+        cuda_graph_id += 1
+        if memory_pool == 'default':
+            memory_pool_id = 0
+        elif memory_pool == 'new':
+            memory_pool_id = CoreCUDAGraph.gen_new_memory_pool_id()
+        else:
+            raise ValueError(
+                "memory_pool should be one of default or new under static mode, but got",
+                memory_pool)
+        return _cuda_graph_guard(
+            mode + ';' + str(memory_pool_id) + ';' +
+            graph_id)(lambda *args, **kwargs: function(*args, **kwargs))
+
     from paddle.jit import to_static
     from paddle.nn import Layer
     new_function = to_static(function)
diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index 0ca69b5f94d..c37ac87da71 100755
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -236,7 +236,11 @@ def _pretty_op_desc_(op_desc, prefix):
     return out_s
 
 
-def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
+def _add_needed_descs_to_block(descs,
+                               block,
+                               main_block,
+                               in_memory_vars,
+                               grad_op_id_to_fwd_op=None):
     if len(descs) == 0:
         return []
     result_descs = []
@@ -244,8 +248,11 @@
         core.op_proto_and_checker_maker.kOpRoleAttrName()
     backward = core.op_proto_and_checker_maker.OpRole.Backward
     for desc in descs:
+        origin_desc = desc
+        origin_is_operator = False
         if isinstance(desc, framework.Operator):
             desc = desc.desc
+            origin_is_operator = True
         if isinstance(desc, tuple):
             desc = desc[0]
         is_needed = False
@@ -255,6 +262,8 @@
             if name not in in_memory_vars:
                 is_needed = True
         if is_needed:
+            if origin_is_operator and grad_op_id_to_fwd_op is not None:
+                grad_op_id_to_fwd_op[desc.original_id()] = origin_desc
             new_op_desc = block.desc.append_op()
             new_op_desc.copy_from(desc)
             new_op_desc._set_attr(op_role_attr_name, backward)
@@ -264,7 +273,7 @@
     return result_descs
 
 
-def _add_descs_to_block(descs, block):
+def _add_descs_to_block(descs, block, grad_op_id_to_fwd_op=None):
     if len(descs) == 0:
         return []
     result_descs = []
@@ -273,6 +282,9 @@
     backward = core.op_proto_and_checker_maker.OpRole.Backward
     for desc in descs:
         if isinstance(desc, framework.Operator):
+            # for recompute, should record recompute ops
+            if grad_op_id_to_fwd_op is not None:
+                grad_op_id_to_fwd_op[desc.desc.original_id()] = desc
             desc = desc.desc
         if isinstance(desc, tuple):
             desc = desc[0]
@@ -489,7 +501,10 @@ def _accumulate_gradients_by_add_ops_(var_name,
         renamed_vars[var_name] = [var_name]
 
 
-def _addup_repetitive_outputs_(op_descs, block_idx, grad_var_to_var=None):
+def _addup_repetitive_outputs_(op_descs,
+                               block_idx,
+                               grad_var_to_var=None,
+                               grad_op_id_to_fwd_op=None):
     """
     In backward part, an variable may be the output of more than one ops.
     And one op may yield its multiple outputs to the same variable.
@@ -500,6 +515,7 @@
         grad_var_to_var(dict): used to build the mapping between grad var name and
             forward var name. Only for auto parallel.
     """
+
     _MAX_ADD_NUM_ = framework._global_flags()['FLAGS_max_inplace_grad_add']
     #pending_sum_ops = []
     pending_sum_ops = collections.OrderedDict()
@@ -604,6 +620,7 @@
                                                   len(op_descs),
                                                   var_device[var_name])
 
+    op_descs_len = len(op_descs)
     # sum_op descs are sorted according to their insert position
     for key, value in collections.OrderedDict(
             reversed(list(pending_sum_ops.items()))).items():
@@ -614,12 +631,18 @@
         # If not reverse, we first insert 'a' at idx 1, it becomes [0, 1, 'a', 2], and then insert 'b' at idx 2, it becomes [0, 1, 'a', 'b', 2].
         idx = key
         for i, op in enumerate(value):
+            # update the mapping between fwd and bwd
+            target_idx = idx - 1 if idx == op_descs_len else idx + i
+            if grad_op_id_to_fwd_op is not None and grad_op_id_to_fwd_op.get(
+                    op_descs[target_idx].original_id(), None) is not None:
+                grad_op_id_to_fwd_op[op.original_id()] = grad_op_id_to_fwd_op[
+                    op_descs[target_idx].original_id()]
             op_descs.insert(idx + i, op)
 
     return op_descs
 
 
-def _remove_no_grad_branch_(op_descs, no_grad_set):
+def _remove_no_grad_branch_(op_descs, no_grad_set, grad_op_id_to_fwd_op=None):
     """
     Remove unnecessary grad ops
     A grad op can be removed in two cases:
@@ -653,9 +676,14 @@
                 x_in = _strip_grad_suffix_(arg)
                 # the reason should be: arg can be input of another grad op
                 # and the op is a not-to-remove op
-                to_insert.append(
-                    (_create_op_desc_("fill_zeros_like", {"X": [x_in]},
-                                      {"Out": [arg]}, {}), idx))
+                new_op_desc = _create_op_desc_("fill_zeros_like", {"X": [x_in]},
+                                               {"Out": [arg]}, {})
+                # update the mapping between fwd and bwd
+                if grad_op_id_to_fwd_op is not None and grad_op_id_to_fwd_op.get(
+                        op_desc.original_id(), None) is not None:
+                    grad_op_id_to_fwd_op[new_op_desc.original_id(
+                    )] = grad_op_id_to_fwd_op[op_desc.original_id()]
+                to_insert.append((new_op_desc, idx))
 
     list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)])
 
@@ -794,9 +822,13 @@ def serialize_op_decs(op_desc):
     return proto.__str__()
 
 
-def _append_backward_ops_with_checkpoints_(block, ops, target_block,
-                                           no_grad_dict, grad_to_var,
-                                           checkpoints):
+def _append_backward_ops_with_checkpoints_(block,
+                                           ops,
+                                           target_block,
+                                           no_grad_dict,
+                                           grad_to_var,
+                                           checkpoints,
+                                           grad_op_id_to_fwd_op=None):
     """
     Create grad ops with forward ops, and insert them into given block
 
@@ -926,12 +958,19 @@
                                 _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+
+            # record the mapping between fwd and bwd
+            if grad_op_id_to_fwd_op is not None:
+                for op_desc in grad_op_desc:
+                    grad_op_id_to_fwd_op[op_desc.original_id()] = op
+
             # Set device for grad_op according to forward Op
             if op.desc.has_attr(device_attr_name):
                 op_device = op.desc.attr(device_attr_name)
                 for op_desc in grad_op_desc:
                     op_desc._set_attr(device_attr_name, op_device)
-            added_descs = _add_descs_to_block(grad_op_desc, local_block)
+            added_descs = _add_descs_to_block(grad_op_desc, local_block,
+                                              grad_op_id_to_fwd_op)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
 
@@ -945,12 +984,19 @@
                                 _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+
+            # record the mapping between fwd and bwd
+            if grad_op_id_to_fwd_op is not None:
+                for op_desc in grad_op_desc:
+                    grad_op_id_to_fwd_op[op_desc.original_id()] = op
+
             # Set device for grad_op according to forward Op
             if op.desc.has_attr(device_attr_name):
                 op_device = op.desc.attr(device_attr_name)
                 for op_desc in grad_op_desc:
                     op_desc._set_attr(device_attr_name, op_device)
-            added_descs = _add_descs_to_block(grad_op_desc, local_block)
+            added_descs = _add_descs_to_block(grad_op_desc, local_block,
+                                              grad_op_id_to_fwd_op)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
 
@@ -984,8 +1030,10 @@
 
         # 3.a. add ops in current recompute_segment as forward recomputation ops
         buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block,
-                                                  vars_in_memory)
-        added_descs = _add_descs_to_block(ff_ops, local_block)
+                                                  vars_in_memory,
+                                                  grad_op_id_to_fwd_op)
+        added_descs = _add_descs_to_block(ff_ops, local_block,
+                                          grad_op_id_to_fwd_op)
 
         # 3.b. rename all non-checkpoint variables in recomputation ops
         for key in var_name_dict:
@@ -999,6 +1047,12 @@
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op_desc, cpt.to_text(no_grad_dict[block.idx]), [])
 
+            # record the mapping between fwd and bwd
+            if grad_op_id_to_fwd_op is not None:
+                for g_op_desc in grad_op_desc:
+                    grad_op_id_to_fwd_op[g_op_desc.original_id(
+                    )] = grad_op_id_to_fwd_op[op_desc.original_id()]
+
             # Set device for grad_op according to forward Op
             if op_desc.has_attr(device_attr_name):
                 op_device = op_desc.attr(device_attr_name)
@@ -1011,11 +1065,14 @@
             grad_to_var.update(op_grad_to_var)
 
     # 3.d. add sum op for repetitive_outputs
-    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx)
+    grad_op_descs = _addup_repetitive_outputs_(
+        grad_op_descs, block.idx, grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
     # 4) remove no grad branch as it is in _remove_no_grad_branch_
     grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
-                                            no_grad_dict[block.idx])
-    added_descs = _add_descs_to_block(grad_op_descs, target_block)
+                                            no_grad_dict[block.idx],
+                                            grad_op_id_to_fwd_op)
+    added_descs = _add_descs_to_block(grad_op_descs, target_block,
+                                      grad_op_id_to_fwd_op)
 
     return program_stat, checkpoints_name, vars_should_be_hold, recompute_segments
 
@@ -1090,7 +1147,8 @@ def _append_backward_ops_(block,
                           input_grad_names_set=None,
                           op_path_dict=None,
                           distop_context=None,
-                          rename_var_map=None):
+                          rename_var_map=None,
+                          grad_op_id_to_fwd_op=None):
     """
     Create all grad ops, and insert them into given block
 
@@ -1152,9 +1210,15 @@
             pre_input_grad_names_set = copy.copy(input_grad_names_set)
             input_grad_names_set = None
             sub_block_path = op_path_dict[op._block_attr_id("sub_block")]
-            _append_backward_ops_(sub_block, sub_block_path, grad_sub_block,
-                                  no_grad_dict, grad_to_var, callbacks,
-                                  input_grad_names_set, op_path_dict)
+            _append_backward_ops_(sub_block,
+                                  sub_block_path,
+                                  grad_sub_block,
+                                  no_grad_dict,
+                                  grad_to_var,
+                                  callbacks,
+                                  input_grad_names_set,
+                                  op_path_dict,
+                                  grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
             input_grad_names_set = pre_input_grad_names_set
 
             program._rollback()
@@ -1164,6 +1228,11 @@
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
 
+        # record the mapping between fwd and bwd
+        if grad_op_id_to_fwd_op is not None:
+            for op_desc in grad_op_desc:
+                grad_op_id_to_fwd_op[op_desc.original_id()] = op
+
         # Build the mapping between the forward op and backward op (Only for auto parallel)
         if distop_context is not None:
             update_distop_context(distop_context, op_grad_to_var,
@@ -1251,13 +1320,17 @@
             grad_var_to_var = distop_context.grad_var_to_var[
                 program._appending_grad_times]
     # sum parameter's gradients' var given multiple var gradient
-    grad_op_descs = _addup_repetitive_outputs_(grad_op_descs, block.idx,
-                                               grad_var_to_var)
+    grad_op_descs = _addup_repetitive_outputs_(
+        grad_op_descs,
+        block.idx,
+        grad_var_to_var,
+        grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
 
     # if all outputs of the grad op are in no_grad_set, then just remove and fill zero
     # if all inputs of the grad op are in no_grad_set, just remove this op
     grad_op_descs = _remove_no_grad_branch_(grad_op_descs,
-                                            no_grad_dict[block.idx])
+                                            no_grad_dict[block.idx],
+                                            grad_op_id_to_fwd_op)
 
     # remove some backward ops
     not_need_ops = _find_not_need_ops(grad_op_descs, ops, input_grad_names_set)
@@ -1585,6 +1658,9 @@ def append_backward(loss,
             p_g_list6 = paddle.static.append_backward(loss=avg_loss, parameter_list=all_weights, no_grad_set=set(all_weights))
 
     """
 
+    grad_op_id_to_fwd_op = {
+    }  # for cuda graph usage, recording the mapping between grad op original id to fwd op
+
     check_type(loss, 'loss', framework.Variable,
                'paddle.static.append_backward')
@@ -1644,7 +1720,9 @@
 
     grad_to_var = dict()
 
+    # pass the cuda_graph_attr to the fill_constant which generates the loss_grad
     op_desc = _create_loss_op_desc_(loss)
+    grad_op_id_to_fwd_op[op_desc.original_id()] = loss.op
     target_grad_block.desc.append_op().copy_from(op_desc)
 
     for block_idx in son_parent_block_idx_dict:
@@ -1690,7 +1768,8 @@
                     root_block,
                     no_grad_dict,
                     grad_to_var,
-                    checkpoints)
+                    checkpoints,
+                    grad_op_id_to_fwd_op)
         else:
             _append_backward_ops_(
                 block,  # the block where forward ops are in
@@ -1702,7 +1781,7 @@
                 input_grad_names_set=input_grad_names_set,
                 op_path_dict=op_path_dict,
                 distop_context=distop_context,
-            )
+                grad_op_id_to_fwd_op=grad_op_id_to_fwd_op)
 
     grad_info_map = dict()
 
@@ -1722,6 +1801,12 @@
     program.current_block_idx = current_block_idx
     program._sync_with_cpp()
 
+    # for cuda graph, copy the cuda graph attr from forward op to backward op
+    for op in target_grad_block.ops:
+        if grad_op_id_to_fwd_op.get(op.desc.original_id(), None) is not None:
+            fwd_op = grad_op_id_to_fwd_op[op.desc.original_id()]
+            op._cuda_graph_attr = fwd_op._cuda_graph_attr
+
     if parameter_list is not None:
         check_type(parameter_list, 'parameter_list', (list, tuple, set),
                    'fluid.backward.append_backward')
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index e0b4f8d19e8..fdd5c0b47b4 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -81,6 +81,7 @@ global_prog_seed = 0
 _current_pipeline_stage = None
 _already_patch_eager_tensor = False
 _already_patch_varbase = False
+_current_cuda_graph_mode = None
 _global_flags_ = core.globals()
 
 # Some explanation of our execution system 2022.03
@@ -2622,6 +2623,9 @@ class Operator(object):
                 op_attrs = dict()
             del attrs
 
+            # attr for static mode cuda graph
+            self._cuda_graph_attr = _current_cuda_graph_mode
+
             op_maker = core.op_proto_and_checker_maker
 
             if op_maker.kOpRoleAttrName() not in op_attrs:
@@ -7017,6 +7021,37 @@ def device_guard(device=None):
         switch_device(pre_device)
 
 
+def _switch_cuda_graph_mode(cuda_graph_attr):
+    global _current_cuda_graph_mode
+    pre_mode = _current_cuda_graph_mode
+    _current_cuda_graph_mode = cuda_graph_attr
+    return pre_mode
+
+
+@signature_safe_contextmanager
+def _cuda_graph_guard(cuda_graph_attr=None):
+    """
+
+    Note:
+        The API only supports static mode.
+
+    A context manager that specifies the cuda_graph_mode which indicating the cuda graph capture under static mode.
+
+    Args:
+        cuda_graph_attr(str|None): The cuda graph attr with the format of:
+                                   cuda_graph_capture_mode;memory_pool_id;cuda_graph_id
+    """
+    assert not _non_static_mode(
+    ), "cuda_graph_guard only works under static mode"
+    assert core.is_compiled_with_cuda(
+    ), "cuda_graph_guard context can be only used when Paddle is compiled with cuda"
+    pre_mode = _switch_cuda_graph_mode(cuda_graph_attr)
+    try:
+        yield
+    finally:
+        _switch_cuda_graph_mode(pre_mode)
+
+
 def set_flags(flags):
     """
     This function sets the GFlags value in Paddle.
diff --git a/python/paddle/fluid/tests/unittests/test_cuda_graph_partial_graph_static.py b/python/paddle/fluid/tests/unittests/test_cuda_graph_partial_graph_static.py
new file mode 100644
index 00000000000..b70be74ea92
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_cuda_graph_partial_graph_static.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import unittest
+import numpy as np
+from paddle.device.cuda.graphs import wrap_cuda_graph, is_cuda_graph_supported
+
+paddle.enable_static()
+
+
+class SimpleModel(nn.Layer):
+
+    def __init__(self, in_size, out_size):
+        super(SimpleModel, self).__init__()
+        self.linear = nn.Linear(in_size, out_size)
+        self.dropout_1 = paddle.nn.Dropout(0.1)
+        self.relu = nn.ReLU()
+        self.dropout_2 = paddle.nn.Dropout(0.5)
+        self.gelu = nn.GELU()
+
+    def forward(self, x):
+        x = self.linear(x)
+        x = self.dropout_1(x)
+        x = self.relu(x)
+        x = self.dropout_2(x)
+        x = self.gelu(x)
+        return x
+
+
+class TestCudaGraphAttrAll(unittest.TestCase):
+
+    def test_all_program(self):
+        if not is_cuda_graph_supported():
+            return
+        main_prog = paddle.static.Program()
+        start_prog = paddle.static.Program()
+        with paddle.static.program_guard(main_prog, start_prog):
+            model = SimpleModel(10, 20)
+            cuda_graph_model = wrap_cuda_graph(model)
+            x = paddle.static.data(shape=[3, 10], dtype='float32', name='x')
+            y = cuda_graph_model(x)
+            loss = paddle.mean(y)
+            opt = paddle.optimizer.SGD()
+            opt.minimize(loss)
+            block = main_prog.global_block()
+            for op in block.ops:
+                if op._cuda_graph_attr is None:
+                    # the loss and opt are not wrapped
+                    assert op.type in [
+                        'sgd', 'reduce_mean', 'fill_constant',
+                        'reduce_mean_grad'
+                    ]
+                else:
+                    assert op._cuda_graph_attr == 'thread_local;0;0'
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab
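
A minimal usage sketch of the static-mode path this patch adds, mirroring the unit test above. It assumes Paddle is compiled with CUDA and that CUDA graph capture is supported; the layer, variable names, and shapes below are illustrative only, not part of the patch.

    import paddle
    import paddle.nn as nn
    from paddle.device.cuda.graphs import wrap_cuda_graph

    paddle.enable_static()

    main_prog = paddle.static.Program()
    start_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, start_prog):
        # Under static mode, wrap_cuda_graph() no longer goes through to_static;
        # it wraps the callable in _cuda_graph_guard, so every op created inside
        # the wrapped call records a _cuda_graph_attr of the form
        # "capture_mode;memory_pool_id;graph_id".
        wrapped = wrap_cuda_graph(nn.Linear(10, 20), mode='thread_local')
        x = paddle.static.data(shape=[3, 10], dtype='float32', name='x')
        loss = paddle.mean(wrapped(x))

    # Ops built inside the wrapped call carry the attr (e.g. 'thread_local;0;0'
    # for the first wrapped graph with the default memory pool); ops built
    # outside the wrapper, such as reduce_mean here, keep _cuda_graph_attr == None.
    for op in main_prog.global_block().ops:
        print(op.type, op._cuda_graph_attr)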