/**
 * \file imperative/src/impl/op_def.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/op_def.h"

#include <sstream>

#include "megbrain/imperative/ops/opr_attr.h"

#include "./op_trait.h"

namespace mgb {
namespace imperative {

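// Rebuild an OpDef from a graph operator node. Node types without a
// dedicated trait fall back to the generic OprAttr representation.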
std::shared_ptr<OpDef> OpDef::make_from_op_node(cg::OperatorNodeBase* node) {
    OpTrait* trait = OpTrait::find_by_typeinfo(node->dyn_typeinfo());
    if (!trait) {
        // TODO: register `make_from_op_node` for each OperatorNode
        // instead of forwarding to OprAttr
        trait = OpTrait::find_by_typeinfo(OprAttr::typeinfo());
    }
    mgb_assert(trait);
    return trait->make_from_op_node(node);
}

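// Decide how the op should be dispatched for the given inputs;
// the policy is supplied by the registered OpTrait.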
DispatchMode OpDef::decide_dispatch_mode(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    return def.trait()->decide_dispatch_mode(def, inputs);
}

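// Eagerly execute the op on concrete tensors and return the outputs.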
SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
        const OpDef& def, SmallVector<TensorPtr> inputs) {
    return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
}
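
// Variant of eager execution that writes into caller-provided
// DeviceTensorND outputs instead of returning new tensors.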
void OpDef::apply_on_device_tensornd(
        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    def.trait()->apply_on_device_tensornd(def, inputs, outputs);
}

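// Insert the op into a computing graph, returning its output VarNodes.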
VarNodeArray OpDef::apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    return def.trait()->apply_on_var_node(def, inputs);
}

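// Infer output descriptors from input descriptors; the returned bool
// reports whether the (fallible) inference fully succeeded.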
std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    return def.trait()->infer_output_attrs_fallible(def, inputs);
}

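// Backward graphs are memoized in a thread-local cache keyed on the op
// instance, the input descriptors, and the requires-grad/has-grad masks.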
EncodedSubgraph OpDef::make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    using BackwardGraphCache =
            OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
    thread_local auto cache = std::make_unique<BackwardGraphCache>();
    BackwardGraphCache::key_t cache_key{
            const_cast<OpDef&>(def).shared_from_this(),
            inputs,
            {input_requires_grad, output_has_grad}};
    auto iter = cache->find(cache_key);
    if (iter == cache->end()) {
        iter = cache->insert({cache_key, def.trait()->make_backward_graph(
                                                 def, inputs, input_requires_grad,
                                                 output_has_grad)})
                       .first;
    }
    return iter->second;
}

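// Key-value properties of the op, used below by to_string().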
std::vector<std::pair<const char*, std::string>> OpDef::props(const OpDef& def) {
    return def.trait()->props(def);
}

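// Forward graphs are memoized the same way as backward graphs, keyed on
// the op instance and the input descriptors.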
EncodedSubgraph OpDef::make_forward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    using ForwardGraphCache =
            OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
    thread_local auto cache = std::make_unique<ForwardGraphCache>();
    ForwardGraphCache::key_t cache_key{
            const_cast<OpDef&>(def).shared_from_this(), inputs};
    auto iter = cache->find(cache_key);
    if (iter == cache->end()) {
        iter = cache->insert({cache_key, def.trait()->make_forward_graph(def, inputs)})
                       .first;
    }
    return iter->second;
}

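// Render the op as "Name{prop: value,...}".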
std::string OpDef::to_string() const {
    std::string builder = trait()->make_name(*this) + "{";
    for (auto&& [name, value] : props(*this)) {
        builder += name;
        builder += ": ";
        builder += value;
        builder += ",";
    }
    return builder + "}";
}

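// Hashing and equality are delegated to the registered trait.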
size_t OpDef::hash() const {
    return trait()->hash(*this);
}

bool OpDef::is_same_st(const Hashable& rhs) const {
    return trait()->is_same_st(*this, static_cast<const OpDef&>(rhs));
}

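// Look up the OpTrait for this concrete OpDef type on first use and
// cache it in m_trait.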
const OpTrait* OpDef::trait() const {
    if (!m_trait) {
        m_trait = OpTrait::find_by_typeinfo(dyn_typeinfo());
        mgb_throw_if(
                !m_trait, MegBrainError, "can not find op_trait by %s",
                dyn_typeinfo()->name);
    }
    return m_trait;
}

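// An op may carry a scope; when present it prefixes the display name,
// as in "scope.OpName".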
const std::string OpDef::scope() const {
    return m_scope;
}

void OpDef::set_scope(const std::string& scope) {
    m_scope = scope;
}

const std::string OpDef::make_name() const {
    if (m_scope.empty())
        return trait()->make_name(*this);
    return m_scope + "." + trait()->make_name(*this);
}

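// Thread-local hook for allocating tensor storage. set_allocator() may
// be called at most once per thread, and must happen before allocate().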
static thread_local OpDef::allocator_t local_allocator;

void OpDef::set_allocator(allocator_t allocator) {
    mgb_assert(!local_allocator, "allocator has been set before");
    local_allocator = allocator;
}

DeviceTensorStorage::RawStorage OpDef::allocate(CompNode device, size_t size) const {
    return local_allocator(device, size);
}

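// Pretty-print the subgraph in a small IR-like form:
//   (%0, %1) => { %2 = Op %0 %1 ... }
// Single-element Float32/Int32 constants are printed inline by value.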
std::string Subgraph::repr() const {
    std::ostringstream buf;
    buf << "(";
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (i > 0)
            buf << ", ";
        buf << "%" << inputs[i];
    }
    buf << ") => {\n";
    auto fmt_const = [](size_t i, const TensorPtr& t) {
        if (t->shape().ndim == 1 && t->shape()[0] == 1) {
            auto&& v = t->get_value();
            if (v.dtype() == dtype::Float32{}) {
                return std::to_string(*v.ptr<dt_float32>());
            } else if (v.dtype() == dtype::Int32{}) {
                return std::to_string(*v.ptr<int32_t>());
            }
        }
        return std::string("%c") + std::to_string(i);
    };
    std::unordered_map<size_t, std::string> const_reps;
    for (auto&& [i, t] : constants) {
        const_reps.emplace(i, fmt_const(i, t));
    }
    for (auto& [op, ins, outs] : exprs) {
        buf << "  ";
        if (outs.size()) {
            for (size_t i = 0; i < outs.size(); ++i) {
                if (i > 0)
                    buf << ", ";
                buf << "%" << outs[i];
            }
            buf << " = ";
        }
        if (auto* p = op->try_cast_final<OprAttr>()) {
            buf << p->type;
        } else {
            buf << op->make_name();
        }
        for (size_t i : ins) {
            buf << " ";
            auto&& it = const_reps.find(i);
            if (it != const_reps.end()) {
                buf << it->second;
            } else {
                buf << "%" << i;
            }
        }
        buf << "\n";
    }
    buf << "  ";
    if (outputs.size()) {
        for (size_t i = 0; i < outputs.size(); ++i) {
            if (i > 0)
                buf << ", ";
            buf << "%" << outputs[i];
        }
    } else {
        buf << "()";
    }
    buf << "\n}\n";
    return buf.str();
}

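// A subgraph is "single" when it consists of exactly one expression
// whose inputs and outputs are the subgraph's own; as_single() then
// exposes that expression's op directly.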
bool Subgraph::is_single() const {
    if (exprs.size() != 1) {
        return false;
    }
    auto& expr = exprs.at(0);
    return expr.inputs == inputs && expr.outputs == outputs;
}

std::shared_ptr<OpDef> Subgraph::as_single() const {
    if (is_single()) {
        return exprs.at(0).op;
    } else {
        return nullptr;
    }
}

bool Subgraph::operator==(const Subgraph& rhs) const {
    mgb_assert(false, "Not Implemented");
}

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}