tensor_info.h 3.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
/**
 * \file imperative/src/impl/interpreter/tensor_info.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/imperative/physical_tensor.h"
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/utils/to_string.h"

#include <algorithm>
#include <memory>
#include <utility>

namespace mgb::imperative {

namespace interpreter::intl {

//! Eviction state tag stored in TensorInfo::evict_type.
//! NOTE(review): the meaning of each state is inferred from its name only
//! (SWAP presumably moves the value off-device, DROP presumably discards it
//! for later recomputation) -- confirm against the interpreter implementation.
enum EvictType {
    NONE,  //!< value 0: default state
    SWAP,  //!< value 1
    DROP,  //!< value 2
};

// Forward declaration so the shared-pointer alias below can be introduced
// before the full definition of TensorInfo.
struct TensorInfo;
// Shared-ownership handle to a TensorInfo.
using TensorInfoPtr = std::shared_ptr<TensorInfo>;

/*!
 * \brief Bookkeeping record for one tensor, including its links to the
 *      ComputePath that produced it and the ComputePaths that consume it.
 *
 * All links between TensorInfo and ComputePath are raw, non-owning pointers.
 * A ComputePath is heap-allocated by ComputePath::make() and deleted by
 * detach_producer() once every one of its outputs has been detached.
 */
struct TensorInfo {
    //! Tensor properties that can be referred to by name.
    enum Prop {
        Device, Shape, DType, DevValue, HostValue
    };

    //! Identifier of this record.
    uint64_t id;
    //! Handle to the underlying tensor.
    TensorPtr ptr;
    //! Logical descriptor of the tensor.
    LogicalTensorDesc desc;

    // FIXME: broken by drop
    bool value_fetched = false;
    bool invalid = false;
    bool allow_delete = false;

    //! Eviction state; starts as NONE.
    EvictType evict_type = NONE;

    //! Host-side tensor value.
    //! NOTE(review): presumably only meaningful once value_fetched is set -- confirm.
    HostTensorND h_value;

    // reserved for auto drop
    size_t pinned = 0;
    size_t recompute_times = 0;

    /*!
     * \brief One recorded op application: the op plus the tensors it read
     *      and wrote.
     */
    struct ComputePath {
        //! The op that was applied.
        std::shared_ptr<OpDef> op;
        //! Inputs in call order; may contain the same tensor more than once.
        SmallVector<TensorInfo*> inputs;
        //! \p inputs sorted by pointer value with duplicates removed.
        SmallVector<TensorInfo*> unique_inputs;
        //! Outputs in call order; an entry is reset to nullptr when that
        //! output calls TensorInfo::detach_producer().
        SmallVector<TensorInfo*> outputs;

        //! Number of outputs that are still attached (i.e. non-null).
        size_t ref_cnt() const {
            return outputs.size() - std::count(outputs.begin(), outputs.end(), nullptr);
        }

        /*!
         * \brief Allocate a ComputePath and wire it into the given tensors.
         *
         * The new path is appended to the \c users list of every distinct
         * input and stored as the \c producer of every output; a producer
         * pointer already present on an output is overwritten without being
         * detached.
         *
         * \param op the op that was applied
         * \param inputs input tensors; every entry must be non-null
         * \param outputs output tensors; every entry must be non-null
         * \return the new path, allocated with \c new; it is deleted by
         *      TensorInfo::detach_producer() when its last output detaches
         */
        static ComputePath* make(std::shared_ptr<OpDef> op, SmallVector<TensorInfo*> inputs, SmallVector<TensorInfo*> outputs) {
            auto* path = new TensorInfo::ComputePath();
            path->op = std::move(op);
            // dedup directly in the member (must happen before `inputs` is moved from)
            path->unique_inputs = inputs;
            std::sort(path->unique_inputs.begin(), path->unique_inputs.end());
            path->unique_inputs.erase(
                    std::unique(path->unique_inputs.begin(), path->unique_inputs.end()),
                    path->unique_inputs.end());
            // the by-value parameters are ours to consume: move instead of copying
            path->inputs = std::move(inputs);
            path->outputs = std::move(outputs);
            // attach users
            for (auto* input: path->unique_inputs) {
                input->users.push_back(path);
            }
            // attach producer
            for (auto* output: path->outputs) {
                output->producer = path;
            }
            return path;
        }
    }* producer = nullptr;  //!< path that produced this tensor, or nullptr

    //! Increment the pin counter.
    void pin() {
        ++pinned;
    }

    //! Decrement the pin counter. Calls must be balanced with pin(): the
    //! counter is unsigned and wraps around if decremented at zero.
    void unpin() {
        --pinned;
    }

    /*!
     * \brief Unlink this tensor from the ComputePath that produced it.
     *
     * Clears this tensor's slot in the producer's output list. When that was
     * the last attached output, the producer is also removed from the
     * \c users list of each of its distinct inputs and then deleted.
     * Does nothing if there is no producer.
     */
    void detach_producer() {
        if (!producer) {
            return;
        }
        auto output = std::find(producer->outputs.begin(), producer->outputs.end(), this);
        mgb_assert(output != producer->outputs.end());
        *output = nullptr;
        if (producer->ref_cnt() == 0) {
            for (auto* input: producer->unique_inputs) {
                auto user = std::find(input->users.begin(), input->users.end(), producer);
                // erasing an end iterator is undefined behaviour, so fail
                // loudly if the user link is unexpectedly missing
                mgb_assert(user != input->users.end());
                input->users.erase(user);
            }
            delete producer;
        }
        producer = nullptr;
    }

    //! Paths that take this tensor as an input (one entry per path, even if
    //! the path lists this tensor several times among its inputs).
    SmallVector<ComputePath*> users;
};
}

//! Stringification of TensorInfo::Prop: maps each enumerator to a fixed
//! lower-case name, and any other value to "unknown".
template <>
struct ToStringTrait<interpreter::intl::TensorInfo::Prop>{
    using TensorInfo = interpreter::intl::TensorInfo;

    std::string operator()(TensorInfo::Prop prop) const {
        // start from the fallback and overwrite it for every known enumerator
        const char* label = "unknown";
        switch(prop) {
        case TensorInfo::Device:
            label = "device";
            break;
        case TensorInfo::Shape:
            label = "shape";
            break;
        case TensorInfo::DType:
            label = "dtype";
            break;
        case TensorInfo::DevValue:
            label = "dev_value";
            break;
        case TensorInfo::HostValue:
            label = "host_value";
            break;
        default:
            break;
        }
        return label;
    }
};

}