/**
 * \file imperative/src/impl/interpreter/tensor_info.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/utils/to_string.h"

namespace mgb::imperative {

namespace interpreter::intl {

/*!
 * \brief tag stored in TensorInfo::evict_type
 *
 * NOTE(review): the meaning of each enumerator below is inferred from its
 * name; confirm against the interpreter code that sets `evict_type`.
 */
enum EvictType {
    //! default value of TensorInfo::evict_type
    NONE = 0,
    //! presumably: evicted by swapping the value out
    SWAP = 1,
    //! presumably: evicted by dropping the value
    DROP = 2,
};

28 29
/*!
 * \brief an identifier to specify a component of evicted tensors
 *
 * Each component tracks the sum of the compute costs of its elements, with the
 * union of two components having the sum of each constituent cost.
 */
struct DsuNode {
    //! Creates a node without a parent whose stored value `t` is `_t`.
    DsuNode(double _t): t(_t) {}

    //! Link to the parent node; empty when this node has no parent.
    std::shared_ptr<DsuNode> parent;

    //! Returns true when `parent` is empty.
    //! Const-qualified so it can also be queried through const access.
    bool is_root() const {
        return !bool(parent);
    }

    //! Value given at construction. Presumably the accumulated compute cost
    //! described above -- confirm at the use sites.
    double t;
};

46 47 48 49
// Forward declaration so the alias below can name TensorInfo before the
// struct is defined.
struct TensorInfo;
// Shared-ownership pointer to a TensorInfo.
using TensorInfoPtr = std::shared_ptr<TensorInfo>;

struct TensorInfo {
50 51
    enum Status {
        InvalidStatus, Allocated, Produced, Swapped, Dropped, Deleted,
52 53
    };

54 55 56 57 58
    uint64_t id = -1;
    std::string name;
    // Most attrs of TensorInfo, except `ptr` and `h_value`,
    // were visited read and written in main thread.
    // Lock interpreter when visiting `ptr`.
59 60
    TensorPtr ptr;
    LogicalTensorDesc desc;
61
    MemoryDesc mem_desc;
62

63 64 65
    double compute_time;
    size_t memory;
    double last_used_time;
66

67 68 69 70 71
    bool invalid = false;
    bool allow_delete = false;

    EvictType evict_type = NONE;

72 73 74 75 76 77
    // Status should be only modified in worker thread
    Status status = InvalidStatus;

    // Used by HostCompute and Memory Swap.
    // HostCompute and Swap does not happen in one thread.
    // Maybe a barrier is needed.
78 79 80 81 82
    HostTensorND h_value;

    // reserved for auto drop
    size_t pinned = 0;
    size_t recompute_times = 0;
83 84
    size_t ref_cnt = 0;
    std::shared_ptr<DsuNode> dsu_ptr;
85

86 87 88 89
    // Not reference count, inc when used as input
    size_t ptr_use_count = 0;

    // Used by `Drop` action
90
    struct ComputePath {
91
        uint64_t id;
92 93 94 95 96 97 98 99 100
        std::shared_ptr<OpDef> op;
        SmallVector<TensorInfo*> inputs;
        SmallVector<TensorInfo*> unique_inputs;
        SmallVector<TensorInfo*> outputs;

        size_t ref_cnt() {
            return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr);
        }

101
        static ComputePath* make(uint64_t id, std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs, SmallVector<TensorInfo*> outputs) {
102
            auto* path = new TensorInfo::ComputePath();
103
            path->id = id;
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
            path->op = op;
            path->inputs = inputs;
            path->outputs = outputs;
            // dedup
            SmallVector<TensorInfo*> unique_inputs = inputs;
            std::sort(unique_inputs.begin(), unique_inputs.end());
            unique_inputs.erase(std::unique(unique_inputs.begin(), unique_inputs.end()), unique_inputs.end());
            path->unique_inputs = unique_inputs;
            // attach users
            for (auto input: unique_inputs) {
                input->users.push_back(path);
            }
            // attach producer
            for (auto output: outputs) {
                output->producer = path;
            }
120 121 122 123
            // update ref_cnt
            for (auto input: inputs) {
                input->ref_cnt += outputs.size();
            }
124 125 126
            return path;
        }
    }* producer = nullptr;
127

128 129 130 131 132
    double eval_func(double cost, double free_mem, double cur_time,
                     double param_cost, double param_mem, double param_time, double param_recompute_times) {
        return pow(cost + 1e-3, param_cost) * pow(param_recompute_times, (double)recompute_times)
               / (pow((memory + free_mem) / 1024.0 / 1024.0, param_mem) * pow((double)(cur_time - last_used_time + 1e-3), param_time));
    }
133 134 135 136 137 138 139 140 141

    void pin() {
        ++pinned;
    }

    void unpin() {
        --pinned;
    }

142 143
    // returns true if producer is deleted
    bool detach_producer() {
144
        if (!producer) {
145
            return false;
146 147 148 149
        }
        auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this);
        mgb_assert(output != producer->outputs.end());
        *output = nullptr;
150
        bool deleted = false;
151 152 153 154 155
        if (producer->ref_cnt() == 0) {
            for (auto* input: producer->unique_inputs) {
                input->users.erase(std::find(input->users.begin(), input->users.end(), producer));
            }
            delete producer;
156
            deleted = true;
157 158
        }
        producer = nullptr;
159
        return deleted;
160 161
    }

162 163 164 165
    bool size_exceeds_thd(size_t thd) {
        return memory > thd;
    }

166 167 168 169 170
    SmallVector<ComputePath*> users;
};
}

}