/**
 * \file imperative/src/impl/op_def.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/op_def.h"

#include <sstream>
#include <utility>

#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/resource_manager.h"

#include "./op_trait.h"

namespace mgb {
namespace imperative {

M
Megvii Engine Team 已提交
24
std::shared_ptr<OpDef> OpDef::make_from_op_node(cg::OperatorNodeBase* node) {
25 26 27 28 29 30 31 32 33 34 35
    OpTrait* trait;
    trait = OpTrait::find_by_typeinfo(node->dyn_typeinfo());
    if (!trait) {
        // TODO: register `make_from_op_node` for each OperatorNode
        // instead of forwarding to OprAttr
        trait = OpTrait::find_by_typeinfo(OprAttr::typeinfo());
    }
    mgb_assert(trait);
    return trait->make_from_op_node(node);
}

36
DispatchMode OpDef::decide_dispatch_mode(
M
Megvii Engine Team 已提交
37
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
38 39 40
    return def.trait()->decide_dispatch_mode(def, inputs);
}

41
SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
42
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
43 44 45
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    return def.trait()->apply_on_physical_tensor(
            def, std::move(inputs), output_descs, validated);
46
}
47
void OpDef::apply_on_device_tensornd(
M
Megvii Engine Team 已提交
48 49
        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
50 51 52 53
    def.trait()->apply_on_device_tensornd(def, inputs, outputs);
    return;
}

M
Megvii Engine Team 已提交
54
VarNodeArray OpDef::apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
55 56 57
    return def.trait()->apply_on_var_node(def, inputs);
}

58
std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_fallible(
M
Megvii Engine Team 已提交
59
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
60 61 62
    return def.trait()->infer_output_attrs_fallible(def, inputs);
}

63 64 65 66 67
SmallVector<VarNode::LayoutConstraintCallback> OpDef::get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    return def.trait()->get_input_layout_constraint(def, inputs);
}

M
Megvii Engine Team 已提交
68
EncodedSubgraph OpDef::make_backward_graph(
M
Megvii Engine Team 已提交
69 70 71 72 73
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    using BackwardGraphCache =
            OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
74
    thread_local auto& cache = *ResourceManager::create_local<BackwardGraphCache>();
75
    BackwardGraphCache::key_t cache_key{
M
Megvii Engine Team 已提交
76 77 78
            const_cast<OpDef&>(def).shared_from_this(),
            inputs,
            {input_requires_grad, output_has_grad}};
79 80 81 82 83
    auto iter = cache.find(cache_key);
    if (iter == cache.end()) {
        iter = cache.insert({cache_key, def.trait()->make_backward_graph(
                                                def, inputs, input_requires_grad,
                                                output_has_grad)})
M
Megvii Engine Team 已提交
84
                       .first;
85 86
    }
    return iter->second;
87 88
}

M
Megvii Engine Team 已提交
89
std::vector<std::pair<const char*, std::string>> OpDef::props(const OpDef& def) {
90 91 92
    return def.trait()->props(def);
}

M
Megvii Engine Team 已提交
93
EncodedSubgraph OpDef::make_forward_graph(
M
Megvii Engine Team 已提交
94 95 96
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    using ForwardGraphCache =
            OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
97
    thread_local auto& cache = *ResourceManager::create_local<ForwardGraphCache>();
98
    ForwardGraphCache::key_t cache_key{
M
Megvii Engine Team 已提交
99
            const_cast<OpDef&>(def).shared_from_this(), inputs};
100 101 102
    auto iter = cache.find(cache_key);
    if (iter == cache.end()) {
        iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)})
M
Megvii Engine Team 已提交
103
                       .first;
104 105 106 107
    }
    return iter->second;
}

108
std::string OpDef::to_string() const {
109 110
    std::string builder = trait()->name;
    builder += "{";
M
Megvii Engine Team 已提交
111
    for (auto&& [name, value] : props(*this)) {
112 113 114 115 116 117 118 119
        builder += name;
        builder += ": ";
        builder += value;
        builder += ",";
    }
    return builder + "}";
}

120 121 122 123 124 125 126 127
size_t OpDef::hash() const {
    return trait()->hash(*this);
}

bool OpDef::is_same_st(const Hashable& rhs) const {
    return trait()->is_same_st(*this, static_cast<const OpDef&>(rhs));
}

128 129 130
const OpTrait* OpDef::trait() const {
    if (!m_trait) {
        m_trait = OpTrait::find_by_typeinfo(dyn_typeinfo());
M
Megvii Engine Team 已提交
131 132 133
        mgb_throw_if(
                !m_trait, MegBrainError, "can not find op_trait by %s",
                dyn_typeinfo()->name);
134 135 136 137
    }
    return m_trait;
}

138 139 140 141 142 143 144 145 146 147 148 149 150 151
const std::string OpDef::scope() const {
    return m_scope;
}

// Replace this op's scope with a copy of `scope`.
void OpDef::set_scope(const std::string& scope) { m_scope.assign(scope); }

// Name for this op: the trait-generated name, prefixed by "<scope>." when a
// scope has been set.
const std::string OpDef::make_name() const {
    auto base = trait()->make_name(*this);
    if (!m_scope.empty()) {
        return m_scope + "." + base;
    }
    return base;
}

152 153 154 155 156 157 158 159 160 161 162
static thread_local OpDef::allocator_t local_allocator;

void OpDef::set_allocator(allocator_t allocator) {
    mgb_assert(!local_allocator, "allocator has been set before");
    local_allocator = allocator;
}

DeviceTensorStorage::RawStorage OpDef::allocate(CompNode device, size_t size) const {
    return local_allocator(device, size);
}

163 164 165 166
std::string Subgraph::repr() const {
    std::ostringstream buf;
    buf << "(";
    for (size_t i = 0; i < inputs.size(); ++i) {
M
Megvii Engine Team 已提交
167 168
        if (i > 0)
            buf << ", ";
169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
        buf << "%" << inputs[i];
    }
    buf << ") => {\n";
    auto fmt_const = [](size_t i, const TensorPtr& t) {
        if (t->shape().ndim == 1 && t->shape()[0] == 1) {
            auto&& v = t->get_value();
            if (v.dtype() == dtype::Float32{}) {
                return std::to_string(*v.ptr<dt_float32>());
            } else if (v.dtype() == dtype::Int32{}) {
                return std::to_string(*v.ptr<int32_t>());
            }
        }
        return std::string("%c") + std::to_string(i);
    };
    std::unordered_map<size_t, std::string> const_reps;
    for (auto&& [i, t] : constants) {
        const_reps.emplace(i, fmt_const(i, t));
    }
    for (auto& [op, ins, outs] : exprs) {
        buf << "  ";
        if (outs.size()) {
            for (size_t i = 0; i < outs.size(); ++i) {
M
Megvii Engine Team 已提交
191 192
                if (i > 0)
                    buf << ", ";
193 194 195 196 197 198 199
                buf << "%" << outs[i];
            }
            buf << " = ";
        }
        if (auto* p = op->try_cast_final<OprAttr>()) {
            buf << p->type;
        } else {
200
            buf << op->to_string();
201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
        }
        for (size_t i : ins) {
            buf << " ";
            auto&& it = const_reps.find(i);
            if (it != const_reps.end()) {
                buf << it->second;
            } else {
                buf << "%" << i;
            }
        }
        buf << "\n";
    }
    buf << "  ";
    if (outputs.size()) {
        for (size_t i = 0; i < outputs.size(); ++i) {
M
Megvii Engine Team 已提交
216 217
            if (i > 0)
                buf << ", ";
218 219 220 221 222 223 224 225 226
            buf << "%" << outputs[i];
        }
    } else {
        buf << "()";
    }
    buf << "\n}\n";
    return buf.str();
}

227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246
bool Subgraph::is_single() const {
    if (exprs.size() != 1) {
        return false;
    }
    auto& expr = exprs.at(0);
    return expr.inputs == inputs && expr.outputs == outputs;
}

// The op of the sole expression when is_single() holds; nullptr otherwise.
std::shared_ptr<OpDef> Subgraph::as_single() const {
    if (!is_single()) {
        return nullptr;
    }
    return exprs.at(0).op;
}

// Equality between subgraphs is not implemented. Calling this always fails
// the assertion below.
bool Subgraph::operator==(const Subgraph& rhs) const {
    static_cast<void>(rhs);  // unused until equality is implemented
    mgb_assert(false, "Not Implemented");
    // Unreachable while a failed mgb_assert does not return. This keeps the
    // function from falling off the end of a non-void body (undefined
    // behavior) if the assertion is ever compiled out or made non-fatal.
    return false;
}
}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}