# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import paddle
from paddle.fluid import core
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.core import is_compiled_with_cuda, is_compiled_with_rocm, CUDAPlace
import warnings

if is_compiled_with_cuda() and not is_compiled_with_rocm():
    from paddle.fluid.core import CUDAGraph as CoreCUDAGraph

    def is_cuda_graph_supported():
        return True
else:
    CoreCUDAGraph = None

    def is_cuda_graph_supported():
        return False


ALL_MODES = ["global", "thread_local", "relaxed"]
cuda_graph_id = 0


class CUDAGraph:

    def __init__(self, place=None, mode="thread_local"):
        assert CoreCUDAGraph is not None, "CUDA Graph is only supported on PaddlePaddle compiled with NVIDIA GPU."

        self._graph = None
        if place is None:
            device_id = int(os.environ.get('FLAGS_selected_gpus', 0))
            place = CUDAPlace(device_id)
        self._place = place
        assert mode in ALL_MODES
        self._mode = ALL_MODES.index(mode)

    def capture_begin(self):
        CoreCUDAGraph.begin_capture(self._place, self._mode)

    def capture_end(self):
        self._graph = CoreCUDAGraph.end_capture()

    def replay(self):
        self._graph.replay()

    def reset(self):
        self._graph.reset()

    def print_to_dot_files(self, dirname, flags=None):
        if not isinstance(dirname, (str, bytes)):
            dirname = dirname.name
        os.makedirs(name=dirname, exist_ok=True)
        assert os.path.isdir(
            dirname), "The dirname {} should be a directory".format(dirname)
        if flags is None:
            flags = 2047  # 2047 means all information; it can be any integer inside [1, 2048)
        self._graph.print_to_dot_files(dirname, flags)
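
# A minimal usage sketch of the manual capture API above (illustrative only,
# not executed by this module). It assumes a CUDA build of Paddle, dygraph
# mode, and that this module is importable as paddle.device.cuda.graphs:
#
#     import paddle
#     from paddle.device.cuda.graphs import CUDAGraph
#
#     x = paddle.randn([4, 16])
#     linear = paddle.nn.Linear(16, 16)
#     y = linear(x)  # warm up before capture
#     graph = CUDAGraph(mode="thread_local")
#     graph.capture_begin()
#     y = linear(x)  # kernels launched here are captured, not executed
#     graph.capture_end()
#     graph.replay()  # launch the captured kernels
#     graph.reset()   # release the captured graph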


def wrap_cuda_graph(function, mode="thread_local", memory_pool="default"):
    assert mode in ALL_MODES
    if not paddle.in_dynamic_mode():
        # static mode
        from paddle.fluid.framework import _cuda_graph_guard
        global cuda_graph_id
        graph_id = str(cuda_graph_id)
        cuda_graph_id += 1
        if memory_pool == 'default':
            memory_pool_id = 0
        elif memory_pool == 'new':
            memory_pool_id = CoreCUDAGraph.gen_new_memory_pool_id()
        else:
            raise ValueError(
                "memory_pool should be one of default or new under static mode, but got",
                memory_pool)
        return _cuda_graph_guard(
            mode + ';' + str(memory_pool_id) + ';' +
            graph_id)(lambda *args, **kwargs: function(*args, **kwargs))

    from paddle.jit import to_static
    from paddle.nn import Layer
    new_function = to_static(function)
    if isinstance(function, Layer):
        mock_func = new_function.forward
    else:
        mock_func = new_function
    mock_func._cuda_graph_capture_mode = mode
    if memory_pool == "default":
        mock_func._cuda_graph_pool_id = 0
    elif memory_pool == "new":
        mock_func._cuda_graph_pool_id = CoreCUDAGraph.gen_new_memory_pool_id()
    else:
        if isinstance(memory_pool, Layer):
            mock_func._cuda_graph_pool_id = memory_pool.forward._cuda_graph_pool_id
        else:
            mock_func._cuda_graph_pool_id = memory_pool._cuda_graph_pool_id
    return new_function
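
# A minimal usage sketch of wrap_cuda_graph in dynamic (dygraph) mode
# (illustrative only; the layer and shapes are made up). The attributes set
# above are consumed later when the converted program is executed:
#
#     import paddle
#     from paddle.device.cuda.graphs import wrap_cuda_graph
#
#     layer = paddle.nn.Linear(16, 16)
#     graphed_layer = wrap_cuda_graph(layer, mode="thread_local",
#                                     memory_pool="default")
#     x = paddle.randn([4, 16])
#     y = graphed_layer(x)  # runs through to_static with
#                           # _cuda_graph_capture_mode / _cuda_graph_pool_id set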


def copy_var_desc(dst, src):
    """
    copy var desc from src to dst

    :param dst: framework.VarDesc(cpp), dst var desc, cpp VarDesc instance
    :param src: framework.VarDesc(cpp), src var desc, cpp VarDesc instance
    :return: no return
    """
    dst.set_shape(src.shape)
    dst.set_dtype(src.dtype)
    dst.set_lod_level(src.lod_level)
    dst.set_type(src.type)
    dst.set_persistable(src.persistable)
    dst.set_is_parameter(src.is_parameter)
    dst.set_stop_gradient(src.stop_gradient)


def all_inputs_of_later_op(block, begin_idx):
    """
    find all inputs of ops after an idx, used to determine the logical output of a cuda graph section

    :param block: framework.Block, the original block
    :param begin_idx: int, the idx (not included) after which to collect the later inputs
    :return: a list of input names for all ops after begin_idx
    """
    ins = []
    for idx, op in enumerate(block.ops):
        if idx <= begin_idx:
            continue
        for in_name in op.input_arg_names:
            ins.append(in_name)
    return list(set(ins))


def construct_program_and_find_ins_outs(section, origin_program, section_idx):
    """
    1. Construct a new program for corresponding section
    2. Find all the logical inputs and outputs of a program section

    :param section: list, one cuda graph section, list of ops
    :param origin_program: framework.Program, origin program
    :param section_idx: list, the section ops' idx corresponding to the cuda graph section, a list of idx
    :return: a new program for the cuda graph section
             the logical ins and outs of the cuda graph section
    """
    program = paddle.static.Program()
    block = program.global_block()
    origin_block = origin_program.global_block()
    ins = []
    outs = []
    op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
    later_ins = all_inputs_of_later_op(origin_block, section_idx[-1])

    for op in section:
        for in_name in op.input_arg_names:
            var = origin_block.var(in_name)
            new_var_desc = block.desc.var(var.name.encode("ascii"))
            copy_var_desc(new_var_desc, var)
            if outs.count(in_name) == 0 and ins.count(in_name) == 0:
                # This in var is generated from op outside this section
                # Only record once for same input
                ins.append(in_name)
            elif later_ins.count(in_name) == 0 and outs.count(in_name) > 0:
                # This var is generated by an op inside this section and will only be used inside this section
                outs.remove(in_name)
        for out_name in op.output_arg_names:
            var = origin_block.var(out_name)
            new_var_desc = block.desc.var(var.name.encode("ascii"))
            copy_var_desc(new_var_desc, var)
            # for every output, we add it to the section's outs
            if outs.count(out_name) == 0:
                # Only record one out var even if it will be generated by multi ops.
                # For scenario like this:
                # A = op1(a)
                # A = op2(b)
                # B = op3(A)
                outs.append(out_name)
        new_op_desc = block.desc.append_op()
        new_op_desc.copy_from(op.desc)
        new_op_desc._set_attr(op_role_attr_name, op.attr(op_role_attr_name))

    program._sync_with_cpp()

    return program, [ins, outs]
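
# A small illustration of the in/out classification above, using hypothetical
# op and var names and assuming no op after the section reads B:
#
#     section:              B = op1(A)
#                           C = op2(B)
#     op after the section: D = op3(C)
#
#     op1: A was not produced inside the section -> ins = [A]; outs = [B]
#     op2: B is consumed inside the section and is not in later_ins,
#          so B is dropped from outs; C is produced -> outs = [C]
#     result: ins = [A], outs = [C]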


def get_cuda_graph_sections(program):
    """
    get all sections that should run under cuda graph and the corresponding idx

    :param program: framework.Program, the original program
    :return: A list of cuda graph sections and the corresponding ops' idx in the block.
             Also returns whether the program is in is_test mode.
    """
    block = program.global_block()
    cuda_graph_sections = []  # record all ops in every cuda graph sections
    sections_idx = []  # idx of all ops in every cuda graph sections
    is_test = False  # will be set to True if any op's 'is_test' attr is True

    # ops and their idx between cuda graph wrapped ops; they may belong to a section
    internal_section = []
    internal_idx = []

    current_section = []  # ops of the cuda graph section being recorded
    current_idx = []  # idx of the ops in the section being recorded
    current_cuda_graph_id = -1  # id of the cuda graph section being recorded
    op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
    loss_op_role = int(core.op_proto_and_checker_maker.OpRole.Loss)
    backward_op_role = int(core.op_proto_and_checker_maker.OpRole.Backward)
    loss_grad_op_role = loss_op_role | backward_op_role

    for idx, op in enumerate(block.ops):
        if op.type == 'conditional_block' or op.type == 'while':
            assert op._cuda_graph_attr is None, "Cuda graph does not support conditional_block op or while op."
        if op.has_attr('is_test') and op.attr('is_test'):
            is_test = True
        # find cuda graph sections
        if op._cuda_graph_attr is not None:
            assert isinstance(op._cuda_graph_attr,
                              str), "cuda_graph_attr should be a str"
            cuda_graph_attrs = op._cuda_graph_attr.split(';')
            assert len(cuda_graph_attrs) == 3, "cuda graph attr should have three fields: " \
                                               "cuda graph mode, cuda graph memory pool id, cuda graph id"
            local_cuda_graph_id = int(cuda_graph_attrs[2])
            if local_cuda_graph_id == current_cuda_graph_id:
                if len(internal_section) > 0:
                    assert len(internal_section) == len(
                        internal_idx
                    ), "len of internal section should be equal with len of internal idx"
                    for internal_op in internal_section:
                        op_role = int(internal_op.attr(op_role_attr_name))
                        loss_related = op_role in (loss_op_role,
                                                   loss_grad_op_role)
                        sub_block_related = (internal_op.type
                                             == 'conditional_block'
                                             or internal_op.type == 'while')
                        if loss_related or sub_block_related:
                            # If loss_related is True:
                            # The internal section contains loss related ops.
                            # Although these ops are between two cuda graph sections with the same graph id,
                            # they belong to neither of the two sections.
                            # Loss related ops should be wrapped by the user explicitly.

                            # If sub_block_related is True:
                            # The internal section contains a while op or a conditional_block op.
                            # These ops are not supported by cuda graph, so the section won't be extended.
                            internal_section = []
                            internal_idx = []
                            # Besides clearing the internal section, a new cuda graph section should be recorded
                            assert len(current_section) == len(current_idx), \
                                "num of section's op is not equal with the idx"
                            if len(current_section) > 0:
                                # store previous section
                                cuda_graph_sections.append(current_section)
                                sections_idx.append(current_idx)
                            current_section = []
                            current_idx = []
                            break
                    # ops inserted by some optimizers should be added to the current section
                    for i in range(len(internal_section)):
                        current_section.append(internal_section[i])
                        current_idx.append(internal_idx[i])
                internal_section = []
                internal_idx = []
                current_section.append(op)
                current_idx.append(idx)
            else:
                # current graph id differs from the previous one, start a new cuda graph section
                # internal ops and idx belong to no section, just clear it
                internal_section = []
                internal_idx = []
                current_cuda_graph_id = local_cuda_graph_id  # start record a new section
                assert len(current_section) == len(
                    current_idx
                ), "num of section's op is not equal with num of idx"
                if len(current_section) > 0:
                    # store previous section
                    cuda_graph_sections.append(current_section)
                    sections_idx.append(current_idx)
                current_section = [op]
                current_idx = [idx]
        else:
            # record ops whose cuda_graph_attr is None; they may belong to a section
            internal_section.append(op)
            internal_idx.append(idx)

    # handle the last section
    assert len(current_section) == len(
        current_idx), "num of section's op is not equal with num of idx"
    if len(current_section) > 0:
        # store previous section
        cuda_graph_sections.append(current_section)
        sections_idx.append(current_idx)

    return cuda_graph_sections, sections_idx, is_test
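
# A small illustration of the _cuda_graph_attr format consumed above
# (hypothetical values): an op wrapped under static mode carries a string like
#
#     op._cuda_graph_attr == "thread_local;0;1"
#
# i.e. "<mode>;<memory_pool_id>;<graph_id>", exactly as built by
# wrap_cuda_graph; ops sharing the same graph id are grouped into one section.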


def replace_cuda_graph_section(ins_and_outs, section_program, section_idx,
                               origin_program, cuda_graph_section, order,
                               is_test):
    """
    Use section_program and ins_and_outs to initialize a run_program_op,
    and replace the ops marked by section_idx in the origin program.

    :param ins_and_outs: list, the logical ins and outs of the section program
    :param section_program: framework.Program, the partial program need to run under cuda graph
    :param section_idx: list, the idx need to be removed from origin program
    :param origin_program: framework.Program, the origin program
    :param cuda_graph_section: list, the ops in the current section, used to get the mode, memory pool id and is_test
    :param order: int, the order of current section, used to create unique cuda graph var
    :param is_test: bool, the program is running under is_test or not
    :return: no return
    """
    ins = ins_and_outs[0]
    outs = ins_and_outs[1]
    insert_idx = section_idx[0]
    origin_block = origin_program.global_block()

    for idx in reversed(section_idx):
        # remove all cuda graph marked ops from origin block
        origin_block._remove_op(idx, sync=False)

    mode = None
    memory_pool_id = None

    for op in cuda_graph_section:
        # find the cuda graph mode and memory pool id, determine is test or not
        if op._cuda_graph_attr is not None:
            attrs = op._cuda_graph_attr.split(';')
            mode = attrs[0]
            memory_pool_id = int(attrs[1])
            break

    assert mode is not None and memory_pool_id is not None, \
        "mode and memory pool id should be specified in cuda graph attr"

    cuda_graph_var = origin_block.create_var(
        name="cuda_graph_" + str(order),
        type=core.VarDesc.VarType.RAW,
        persistable=True,
        stop_gradient=True,
    )

    # required as an output of run_program_op, but its content won't actually be used
    out_scope_var = origin_block.create_var(
        name="program_out_scope_" + str(order),
        type=core.VarDesc.VarType.STEP_SCOPES,
        persistable=True,
        stop_gradient=True,
    )

    program_id = _hash_with_id(section_program, ins_and_outs)

    # insert the run_program_op into the block
    origin_block._insert_op(insert_idx,
                            type='run_program',
                            inputs={'X': ins},
                            outputs={
                                'Out': outs,
                                'OutScope': out_scope_var,
                                'CUDAGraph': cuda_graph_var
                            },
                            attrs={
                                'global_block':
                                section_program.global_block(),
                                'start_op_index':
                                0,
                                'end_op_index':
                                len(section_program.global_block().ops),
                                'is_test':
                                is_test,
                                'program_id':
                                program_id,
                                'cuda_graph_capture_mode':
                                mode,
                                'cuda_graph_pool_id':
                                memory_pool_id,
                            })


def cuda_graph_transform(program):
    """
    replace the ops marked with cuda_graph_attr by run_program_op so they run under cuda graph

    :param program: framework.Program, the program to be transformed
    :return: the cuda graph section programs; the caller should hold references to these programs!
    """

    if len(program.blocks) > 1:
        # some sub blocks may be inserted by the optimizer but will not be used during training, just warn here
        warnings.warn(
            "Sub block(s) has been detected in the program. "
            "Cuda graph not support op with sub block, and it will only handle the global block."
        )

    # step 1: get all cuda graph sections.
    # A cuda graph section contains all ops marked with same cuda graph id and
    # some ops inserted by some optimizers (amp, sharding for example) between ops with same id.
    cuda_graph_sections, sections_idx, is_test = get_cuda_graph_sections(
        program)
    assert len(cuda_graph_sections) == len(sections_idx), \
        "num of cuda graph sections is not equal with num of idx sections"

    # step 2: construct new program for each section and find inputs and outputs of each section.
    # The inputs are variables generated outside the section but will be used by this section.
    # The outputs are variables generated by this section and will be used after the end of the section.
    ins_and_outs = []
    section_programs = []
    for i in range(len(cuda_graph_sections)):
        # creating new program for current section
        section_program, ins_outs = construct_program_and_find_ins_outs(
            cuda_graph_sections[i], program, sections_idx[i])
        ins_and_outs.append(ins_outs)
        section_programs.append(section_program)
    assert len(section_programs) == len(cuda_graph_sections), \
        "the num of cuda graph sections should be equal with the num of new program"

    # step 3: replace the ops in original program with run_program_op.
    # Will remove all ops in the section from origin program, and use run_program_op to replace them.
    for i in reversed(range(len(cuda_graph_sections))):
        # carry out the replacement in reversed order, to keep the previous idx intact
        replace_cuda_graph_section(ins_and_outs[i],
                                   section_programs[i],
                                   sections_idx[i],
                                   program,
                                   cuda_graph_sections[i],
                                   order=i,
                                   is_test=is_test)

    # NOTE: the user should hold these programs; for now just return them to the caller
    return section_programs
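
# A minimal end-to-end sketch of the static-mode flow (illustrative only; the
# network is made up, and the returned section programs must be kept alive by
# the caller as noted above):
#
#     import paddle
#     from paddle.device.cuda.graphs import wrap_cuda_graph, cuda_graph_transform
#
#     paddle.enable_static()
#     main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
#     with paddle.static.program_guard(main_prog, startup_prog):
#         x = paddle.static.data('x', shape=[4, 16], dtype='float32')
#         graphed_fc = wrap_cuda_graph(paddle.nn.Linear(16, 16))
#         y = graphed_fc(x)  # ops built here carry a _cuda_graph_attr
#         loss = paddle.mean(y)
#
#     section_programs = cuda_graph_transform(main_prog)  # hold a reference!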