memory_analysis.py 2.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import core
import numpy as np


def get_var_and_memory_size(block, var_name, batch_size=None):
    """Look up a tensor variable in ``block`` and compute its size in bytes.

    Args:
        block: Block searched (recursively) for ``var_name`` through
            ``block._find_var_recursive``.
        var_name: Name of the variable to look up.
        batch_size: Value substituted for the single unknown dimension
            (``None`` or negative) of the variable's shape, if there is one.

    Returns:
        A tuple ``(var, mem_size)`` where ``var`` is the variable found and
        ``mem_size`` is the number of elements multiplied by
        ``core.size_of_dtype(var.dtype)``. ``mem_size`` is 0 when the
        variable has an empty shape.

    Raises:
        AssertionError: If the variable cannot be found, is not a LOD_TENSOR,
            has more than one unknown dimension, or still has an unknown or
            negative dimension after ``batch_size`` has been substituted
            (for example when ``batch_size`` is None).
    """
    var = block._find_var_recursive(var_name)
    assert var is not None, "Variable {} cannot be found".format(var_name)
    assert var.type == core.VarDesc.VarType.LOD_TENSOR, "Variable {} is not Tensor".format(
        var_name)
    shape = list(var.shape)
    if not shape:
        return var, 0

    has_none = False
    for i, s in enumerate(shape):
        if s is None or s < 0:
            # Only one dimension may be unknown, since a single batch_size
            # is all we have to fill it in with.
            assert not has_none, \
                "Variable {} has more than one unknown dimension".format(var_name)
            shape[i] = batch_size
            has_none = True
    # Test for None explicitly: comparing None with an int raises TypeError
    # on Python 3, which would hide the intended assertion message.
    assert all(
        s is not None and s >= 0
        for s in shape), "shape {} is not deterministic".format(shape)

    # Multiply with Python integers so the element count cannot wrap around
    # on platforms where the default NumPy integer is 32 bits wide.
    num_elements = 1
    for s in shape:
        num_elements *= s
    mem_size = int(num_elements) * core.size_of_dtype(var.dtype)
    return var, mem_size


def pre_allocate_memory(size, place):
    """Allocate an INT8 tensor of ``size`` elements on ``place`` and drop it.

    Args:
        size: Length of the one-dimensional tensor; with INT8 elements this
            is also the number of bytes requested.
        place: Device place passed to ``_mutable_data`` for the allocation.

    NOTE(review): presumably this warms up the allocator so that the memory
    can be reused by later allocations -- confirm against the allocator
    strategy in use.
    """
    tensor = core.LoDTensor()
    tensor._set_dims([size])
    tensor._mutable_data(place, core.VarDesc.VarType.INT8)
    # Drop the only Python reference to the tensor right away.
    del tensor


# NOTE: in-place variable reuse is not taken into account yet.
def get_max_memory_info(program, batch_size=None):
    """Estimate the peak temporary memory and total persistable memory.

    The ops of the global block are walked in order. The first time a
    variable appears as an input or output of an op, its size is added either
    to the persistable total or to the running temporary total. After each
    op, the variables reported by ``core._get_eager_deletion_vars`` for that
    op are subtracted from the running temporary total.

    Args:
        program: Program to analyse; it must contain exactly one block.
        batch_size: Value used for the unknown dimension of variable shapes,
            forwarded to ``get_var_and_memory_size``.

    Returns:
        A tuple ``(max_tmp_mem, max_persistable_mem)`` in bytes: the highest
        value reached by the running temporary total, and the summed size of
        all persistable variables used by the ops.

    Raises:
        AssertionError: If the program has more than one block, or if the
            size of a variable cannot be determined.
    """
    assert program.num_blocks == 1, "only support to analysis program with only one block"
    cur_tmp_mem = 0
    max_tmp_mem = 0
    max_persistable_mem = 0
    # Maps each variable name seen so far to (persistable, size in bytes).
    # It doubles as the "visited" set and avoids recomputing sizes on release.
    var_infos = {}

    block = program.global_block()
    gc_vars = core._get_eager_deletion_vars(program.desc, [])[0]
    for i, op in enumerate(block.ops):
        var_names = op.input_arg_names + op.output_arg_names
        for var_name in var_names:
            if var_name in var_infos:
                continue
            var, mem_size = get_var_and_memory_size(block, var_name, batch_size)
            var_infos[var_name] = (var.persistable, mem_size)
            if var.persistable:
                max_persistable_mem += mem_size
            else:
                cur_tmp_mem += mem_size
                max_tmp_mem = max(max_tmp_mem, cur_tmp_mem)

        # A name can occur several times in var_names (same variable used as
        # input and output, or repeated inputs). Intersecting sets makes sure
        # every released variable is subtracted exactly once.
        released = set(gc_vars[i]).intersection(var_names)
        for var_name in released:
            persistable, mem_size = var_infos[var_name]
            # Persistable variables were never added to the temporary total,
            # so they must not be subtracted from it either.
            if not persistable:
                cur_tmp_mem -= mem_size
    return max_tmp_mem, max_persistable_mem