/**
 * \file imperative/src/impl/interpreter/tensor_info.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/utils/to_string.h"

namespace mgb::imperative {

namespace interpreter::intl {

// how (if at all) a tensor's device storage may be reclaimed under
// memory pressure
enum EvictType {
    NONE = 0,  //!< tensor is kept resident; not a candidate for eviction
    SWAP = 1,  //!< presumably swapped out to host memory (cf. h_value) -- TODO confirm
    DROP = 2,  //!< presumably freed and later recomputed via its ComputePath -- TODO confirm
};

/*!
 * \brief an identifier to specify a component of evicted tensors
 * 
 * Each component tracks the sum of the compute costs of its elements, with the
 * union of two components having the sum of each constituent cost.
 */
struct DsuNode {
    //! \param _t initial compute cost of this (singleton) component
    DsuNode(double _t): t(_t) {}

    //! parent in the union-find forest; empty iff this node is a root
    std::shared_ptr<DsuNode> parent;

    //! whether this node is the representative (root) of its component
    // const: only inspects `parent`, never mutates the node
    bool is_root() const {
        return !bool(parent);
    }

    //! accumulated compute cost of the component (meaningful on the root)
    double t;
};

struct TensorInfo;
using TensorInfoPtr = std::shared_ptr<TensorInfo>;

struct TensorInfo {
    enum Prop {
        Device, Shape, DType, DevValue, HostValue
    };

    uint64_t id;
    TensorPtr ptr;
    LogicalTensorDesc desc;

58 59 60 61
    double compute_time;
    size_t memory;
    double last_used_time;
    
62 63 64 65 66 67 68 69 70 71 72 73
    // FIXME: broken by drop
    bool value_fetched = false;
    bool invalid = false;
    bool allow_delete = false;

    EvictType evict_type = NONE;

    HostTensorND h_value;

    // reserved for auto drop
    size_t pinned = 0;
    size_t recompute_times = 0;
74 75
    size_t ref_cnt = 0;
    std::shared_ptr<DsuNode> dsu_ptr;
76 77 78 79 80 81

    struct ComputePath {
        std::shared_ptr<OpDef> op;
        SmallVector<TensorInfo*> inputs;
        SmallVector<TensorInfo*> unique_inputs;
        SmallVector<TensorInfo*> outputs;
82
        double compute_time = 0;
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105

        size_t ref_cnt() {
            return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr);
        }

        static ComputePath* make(std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs, SmallVector<TensorInfo*> outputs) {
            auto* path = new TensorInfo::ComputePath();
            path->op = op;
            path->inputs = inputs;
            path->outputs = outputs;
            // dedup
            SmallVector<TensorInfo*> unique_inputs = inputs;
            std::sort(unique_inputs.begin(), unique_inputs.end());
            unique_inputs.erase(std::unique(unique_inputs.begin(), unique_inputs.end()), unique_inputs.end());
            path->unique_inputs = unique_inputs;
            // attach users
            for (auto input: unique_inputs) {
                input->users.push_back(path);
            }
            // attach producer
            for (auto output: outputs) {
                output->producer = path;
            }
106 107 108 109
            // update ref_cnt
            for (auto input: inputs) {
                input->ref_cnt += outputs.size();
            }
110 111 112
            return path;
        }
    }* producer = nullptr;
113 114 115 116 117 118
  
    double eval_func(double cost, double free_mem, double cur_time,
                     double param_cost, double param_mem, double param_time, double param_recompute_times) {
        return pow(cost + 1e-3, param_cost) * pow(param_recompute_times, (double)recompute_times)
               / (pow((memory + free_mem) / 1024.0 / 1024.0, param_mem) * pow((double)(cur_time - last_used_time + 1e-3), param_time));
    }
119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143

    void pin() {
        ++pinned;
    }

    void unpin() {
        --pinned;
    }

    void detach_producer() {
        if (!producer) {
            return;
        }
        auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this);
        mgb_assert(output != producer->outputs.end());
        *output = nullptr;
        if (producer->ref_cnt() == 0) {
            for (auto* input: producer->unique_inputs) {
                input->users.erase(std::find(input->users.begin(), input->users.end(), producer));
            }
            delete producer;
        }
        producer = nullptr;
    }

144 145 146 147
    bool size_exceeds_thd(size_t thd) {
        return memory > thd;
    }

148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
    SmallVector<ComputePath*> users;
};
}

template <>
struct ToStringTrait<interpreter::intl::TensorInfo::Prop>{
    using TensorInfo = interpreter::intl::TensorInfo;

    //! render a TensorInfo::Prop as its lower-case snake_case name
    std::string operator()(TensorInfo::Prop prop) const {
        // cases listed in enum declaration order for easy cross-checking
        switch (prop) {
        case TensorInfo::Device:
            return "device";
        case TensorInfo::Shape:
            return "shape";
        case TensorInfo::DType:
            return "dtype";
        case TensorInfo::DevValue:
            return "dev_value";
        case TensorInfo::HostValue:
            return "host_value";
        default:
            return "unknown";
        }
    }
};

}