/**
 * \file imperative/src/impl/op_def.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/op_def.h"

#include <sstream>
#include <string>
#include <unordered_map>

#include "megbrain/imperative/ops/opr_attr.h"

#include "./op_trait.h"

namespace mgb {
namespace imperative {

M
Megvii Engine Team 已提交
23
std::shared_ptr<OpDef> OpDef::make_from_op_node(cg::OperatorNodeBase* node) {
24 25 26 27 28 29 30 31 32 33 34
    OpTrait* trait;
    trait = OpTrait::find_by_typeinfo(node->dyn_typeinfo());
    if (!trait) {
        // TODO: register `make_from_op_node` for each OperatorNode
        // instead of forwarding to OprAttr
        trait = OpTrait::find_by_typeinfo(OprAttr::typeinfo());
    }
    mgb_assert(trait);
    return trait->make_from_op_node(node);
}

35
DispatchMode OpDef::decide_dispatch_mode(
M
Megvii Engine Team 已提交
36
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
37 38 39
    return def.trait()->decide_dispatch_mode(def, inputs);
}

40
SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
M
Megvii Engine Team 已提交
41
        const OpDef& def, SmallVector<TensorPtr> inputs) {
42
    return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
43 44
}

M
Megvii Engine Team 已提交
45 46 47 48
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> OpDef::
        infer_output_mem_desc(
                const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors,
                const SmallVector<MemoryDesc>& inputs_mems) {
49 50 51 52
    return def.trait()->infer_output_mem_desc(def, inputs_tensors, inputs_mems);
}

void OpDef::execute(
M
Megvii Engine Team 已提交
53 54
        const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs,
        SmallVector<TensorPtr> workspace) {
55 56 57
    def.trait()->execute(def, std::move(inputs), outputs, std::move(workspace));
}

58
void OpDef::apply_on_device_tensornd(
M
Megvii Engine Team 已提交
59 60
        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
61 62 63 64
    def.trait()->apply_on_device_tensornd(def, inputs, outputs);
    return;
}

M
Megvii Engine Team 已提交
65
VarNodeArray OpDef::apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
66 67 68
    return def.trait()->apply_on_var_node(def, inputs);
}

69
std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_fallible(
M
Megvii Engine Team 已提交
70
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
71 72 73
    return def.trait()->infer_output_attrs_fallible(def, inputs);
}

M
Megvii Engine Team 已提交
74
EncodedSubgraph OpDef::make_backward_graph(
M
Megvii Engine Team 已提交
75 76 77 78 79
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    using BackwardGraphCache =
            OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
80
    thread_local BackwardGraphCache cache;
M
Megvii Engine Team 已提交
81 82 83 84
    decltype(cache)::key_t cache_key{
            const_cast<OpDef&>(def).shared_from_this(),
            inputs,
            {input_requires_grad, output_has_grad}};
85 86
    auto iter = cache.find(cache_key);
    if (iter == cache.end()) {
M
Megvii Engine Team 已提交
87 88 89 90
        iter = cache.insert({cache_key, def.trait()->make_backward_graph(
                                                def, inputs, input_requires_grad,
                                                output_has_grad)})
                       .first;
91 92
    }
    return iter->second;
93 94
}

M
Megvii Engine Team 已提交
95
std::vector<std::pair<const char*, std::string>> OpDef::props(const OpDef& def) {
96 97 98
    return def.trait()->props(def);
}

M
Megvii Engine Team 已提交
99
EncodedSubgraph OpDef::make_forward_graph(
M
Megvii Engine Team 已提交
100 101 102
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    using ForwardGraphCache =
            OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
103
    thread_local ForwardGraphCache cache;
M
Megvii Engine Team 已提交
104 105
    decltype(cache)::key_t cache_key{
            const_cast<OpDef&>(def).shared_from_this(), inputs};
106 107
    auto iter = cache.find(cache_key);
    if (iter == cache.end()) {
M
Megvii Engine Team 已提交
108 109
        iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)})
                       .first;
110 111 112 113
    }
    return iter->second;
}

114
std::string OpDef::to_string() const {
115
    std::string builder = trait()->make_name(*this) + "{";
M
Megvii Engine Team 已提交
116
    for (auto&& [name, value] : props(*this)) {
117 118 119 120 121 122 123 124
        builder += name;
        builder += ": ";
        builder += value;
        builder += ",";
    }
    return builder + "}";
}

125 126 127 128 129 130 131 132
size_t OpDef::hash() const {
    return trait()->hash(*this);
}

bool OpDef::is_same_st(const Hashable& rhs) const {
    return trait()->is_same_st(*this, static_cast<const OpDef&>(rhs));
}

133 134 135
const OpTrait* OpDef::trait() const {
    if (!m_trait) {
        m_trait = OpTrait::find_by_typeinfo(dyn_typeinfo());
M
Megvii Engine Team 已提交
136 137 138
        mgb_throw_if(
                !m_trait, MegBrainError, "can not find op_trait by %s",
                dyn_typeinfo()->name);
139 140 141 142
    }
    return m_trait;
}

143 144 145 146 147 148 149 150 151 152 153 154 155 156
const std::string OpDef::scope() const {
    return m_scope;
}

// Replaces the scope prefix used by make_name(); an empty scope means the
// plain trait-provided name is used.
void OpDef::set_scope(const std::string& scope) {
    m_scope.assign(scope);
}

// Returns the trait-provided op name, prefixed with `<scope>.` when a scope
// has been set.
const std::string OpDef::make_name() const {
    std::string name = trait()->make_name(*this);
    if (m_scope.empty()) {
        return name;
    }
    return m_scope + "." + name;
}

157 158 159 160 161 162 163 164 165 166 167
static thread_local OpDef::allocator_t local_allocator;

void OpDef::set_allocator(allocator_t allocator) {
    mgb_assert(!local_allocator, "allocator has been set before");
    local_allocator = allocator;
}

DeviceTensorStorage::RawStorage OpDef::allocate(CompNode device, size_t size) const {
    return local_allocator(device, size);
}

168 169 170 171
std::string Subgraph::repr() const {
    std::ostringstream buf;
    buf << "(";
    for (size_t i = 0; i < inputs.size(); ++i) {
M
Megvii Engine Team 已提交
172 173
        if (i > 0)
            buf << ", ";
174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
        buf << "%" << inputs[i];
    }
    buf << ") => {\n";
    auto fmt_const = [](size_t i, const TensorPtr& t) {
        if (t->shape().ndim == 1 && t->shape()[0] == 1) {
            auto&& v = t->get_value();
            if (v.dtype() == dtype::Float32{}) {
                return std::to_string(*v.ptr<dt_float32>());
            } else if (v.dtype() == dtype::Int32{}) {
                return std::to_string(*v.ptr<int32_t>());
            }
        }
        return std::string("%c") + std::to_string(i);
    };
    std::unordered_map<size_t, std::string> const_reps;
    for (auto&& [i, t] : constants) {
        const_reps.emplace(i, fmt_const(i, t));
    }
    for (auto& [op, ins, outs] : exprs) {
        buf << "  ";
        if (outs.size()) {
            for (size_t i = 0; i < outs.size(); ++i) {
M
Megvii Engine Team 已提交
196 197
                if (i > 0)
                    buf << ", ";
198 199 200 201 202 203 204
                buf << "%" << outs[i];
            }
            buf << " = ";
        }
        if (auto* p = op->try_cast_final<OprAttr>()) {
            buf << p->type;
        } else {
205
            buf << op->make_name();
206 207 208 209 210 211 212 213 214 215 216 217 218 219 220
        }
        for (size_t i : ins) {
            buf << " ";
            auto&& it = const_reps.find(i);
            if (it != const_reps.end()) {
                buf << it->second;
            } else {
                buf << "%" << i;
            }
        }
        buf << "\n";
    }
    buf << "  ";
    if (outputs.size()) {
        for (size_t i = 0; i < outputs.size(); ++i) {
M
Megvii Engine Team 已提交
221 222
            if (i > 0)
                buf << ", ";
223 224 225 226 227 228 229 230 231
            buf << "%" << outputs[i];
        }
    } else {
        buf << "()";
    }
    buf << "\n}\n";
    return buf.str();
}

232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251
bool Subgraph::is_single() const {
    if (exprs.size() != 1) {
        return false;
    }
    auto& expr = exprs.at(0);
    return expr.inputs == inputs && expr.outputs == outputs;
}

// Returns the op of the sole expression when is_single() holds, else nullptr.
std::shared_ptr<OpDef> Subgraph::as_single() const {
    if (!is_single()) {
        return nullptr;
    }
    return exprs.at(0).op;
}

// Structural equality of subgraphs is not implemented; any call fails the
// assertion below.
bool Subgraph::operator==(const Subgraph& rhs) const {
    static_cast<void>(rhs);  // unused until equality is implemented
    mgb_assert(false, "Not Implemented");
    // Not reached when the assertion fires. The explicit return keeps this
    // non-void function from falling off its end (UB) if the compiler does not
    // treat mgb_assert(false, ...) as noreturn.
    return false;
}

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}