// op_def.cpp
#include "megbrain/imperative/op_def.h"

#include <sstream>

#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/resource_manager.h"

#include "./op_trait.h"

namespace mgb {
namespace imperative {

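// Recover an OpDef from a computing-graph operator node. Ops without a
// dedicated converter registered in their OpTrait fall back to the generic
// OprAttr representation.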
std::shared_ptr<OpDef> OpDef::make_from_op_node(cg::OperatorNodeBase* node) {
    OpTrait* trait = OpTrait::find_by_typeinfo(node->dyn_typeinfo());
    if (!trait) {
        // TODO: register `make_from_op_node` for each OperatorNode
        // instead of forwarding to OprAttr
        trait = OpTrait::find_by_typeinfo(OprAttr::typeinfo());
    }
    mgb_assert(trait);
    return trait->make_from_op_node(node);
}

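// The static methods below are thin wrappers: each one forwards to the
// corresponding entry of the OpTrait registered for the concrete op type.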
DispatchMode OpDef::decide_dispatch_mode(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    return def.trait()->decide_dispatch_mode(def, inputs);
}

SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    return def.trait()->apply_on_physical_tensor(
            def, inputs, output_descs, validated);
}
void OpDef::apply_on_device_tensornd(
        const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
        SmallVector<DeviceTensorND>* outputs) {
    def.trait()->apply_on_device_tensornd(def, inputs, outputs);
}

VarNodeArray OpDef::apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    return def.trait()->apply_on_var_node(def, inputs);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    return def.trait()->infer_output_attrs_fallible(def, inputs);
}

SmallVector<VarNode::LayoutConstraintCallback> OpDef::get_input_layout_constraint(
        const OpDef& def, const SmallVector<TensorPtr>& inputs) {
    return def.trait()->get_input_layout_constraint(def, inputs);
}

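// Backward graphs are memoized per thread: the cache key combines the op
// instance, the input descriptors and the requires-grad / has-grad masks.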
EncodedSubgraph OpDef::make_backward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    using BackwardGraphCache =
            OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
    thread_local auto& cache = *ResourceManager::create_local<BackwardGraphCache>();
    BackwardGraphCache::key_t cache_key{
            const_cast<OpDef&>(def).shared_from_this(),
            inputs,
            {input_requires_grad, output_has_grad}};
    auto iter = cache.find(cache_key);
    if (iter == cache.end()) {
        iter = cache.insert({cache_key, def.trait()->make_backward_graph(
                                                def, inputs, input_requires_grad,
                                                output_has_grad)})
                       .first;
    }
    return iter->second;
}

std::vector<std::pair<const char*, std::string>> OpDef::props(const OpDef& def) {
    return def.trait()->props(def);
}

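// Forward graph decompositions are memoized the same way, keyed by the op
// instance and the input descriptors.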
EncodedSubgraph OpDef::make_forward_graph(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    using ForwardGraphCache =
            OpMethResultCache<EncodedSubgraph, SmallVector<bool>, SmallVector<bool>>;
    thread_local auto& cache = *ResourceManager::create_local<ForwardGraphCache>();
    ForwardGraphCache::key_t cache_key{
            const_cast<OpDef&>(def).shared_from_this(), inputs};
    auto iter = cache.find(cache_key);
    if (iter == cache.end()) {
        iter = cache.insert({cache_key, def.trait()->make_forward_graph(def, inputs)})
                       .first;
    }
    return iter->second;
}

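// Format the op as "<name>{prop: value,...}" using the props reported by its trait.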
std::string OpDef::to_string() const {
    std::string builder = trait()->name;
    builder += "{";
    for (auto&& [name, value] : props(*this)) {
        builder += name;
        builder += ": ";
        builder += value;
        builder += ",";
    }
    return builder + "}";
}

std::string OpDef::name() const {
    return trait()->name;
}

size_t OpDef::hash() const {
    return trait()->hash(*this);
}

bool OpDef::is_same_st(const Hashable& rhs) const {
    return trait()->is_same_st(*this, static_cast<const OpDef&>(rhs));
}

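// Look up the OpTrait for this op's dynamic type on first use and cache the
// pointer; an unknown type is a hard error.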
const OpTrait* OpDef::trait() const {
    if (!m_trait) {
        m_trait = OpTrait::find_by_typeinfo(dyn_typeinfo());
        mgb_throw_if(
                !m_trait, MegBrainError, "cannot find op_trait by %s",
                dyn_typeinfo()->name);
    }
    return m_trait;
}

const std::string OpDef::scope() const {
    return m_scope;
}

void OpDef::set_scope(const std::string& scope) {
    m_scope = scope;
}

const std::string OpDef::make_name() const {
    if (m_scope.empty())
        return trait()->make_name(*this);
    return m_scope + "." + trait()->make_name(*this);
}

const std::string OpDef::type_name() const {
    return trait()->name;
}

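// Optional thread-local hook used to allocate raw device storage; it must be
// installed (at most once per thread) before OpDef::allocate is called.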
static thread_local OpDef::allocator_t local_allocator;

void OpDef::set_allocator(allocator_t allocator) {
    mgb_assert(!local_allocator, "allocator has been set before");
    local_allocator = allocator;
}

DeviceTensorStorage::RawStorage OpDef::allocate(CompNode device, size_t size) const {
    return local_allocator(device, size);
}

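// Pretty-print the subgraph in a small textual form, roughly:
//   (%0, %1) => {
//     %2 = <op> %0 %c0
//     %2
//   }
// Single-element Float32/Int32 constants are inlined by value, other
// constants print as %c<id>, and ordinary values as %<id>.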
std::string Subgraph::repr() const {
    std::ostringstream buf;
    buf << "(";
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (i > 0)
            buf << ", ";
        buf << "%" << inputs[i];
    }
    buf << ") => {\n";
    auto fmt_const = [](size_t i, const TensorPtr& t) {
        if (t->shape().ndim == 1 && t->shape()[0] == 1) {
            auto&& v = t->get_value();
            if (v.dtype() == dtype::Float32{}) {
                return std::to_string(*v.ptr<dt_float32>());
            } else if (v.dtype() == dtype::Int32{}) {
                return std::to_string(*v.ptr<int32_t>());
            }
        }
        return std::string("%c") + std::to_string(i);
    };
    std::unordered_map<size_t, std::string> const_reps;
    for (auto&& [i, t] : constants) {
        const_reps.emplace(i, fmt_const(i, t));
    }
    for (auto& [op, ins, outs] : exprs) {
        buf << "  ";
        if (outs.size()) {
            for (size_t i = 0; i < outs.size(); ++i) {
                if (i > 0)
                    buf << ", ";
                buf << "%" << outs[i];
            }
            buf << " = ";
        }
        if (auto* p = op->try_cast_final<OprAttr>()) {
            buf << p->type;
        } else {
            buf << op->to_string();
        }
        for (size_t i : ins) {
            buf << " ";
            auto&& it = const_reps.find(i);
            if (it != const_reps.end()) {
                buf << it->second;
            } else {
                buf << "%" << i;
            }
        }
        buf << "\n";
    }
    buf << "  ";
    if (outputs.size()) {
        for (size_t i = 0; i < outputs.size(); ++i) {
            if (i > 0)
                buf << ", ";
            buf << "%" << outputs[i];
        }
    } else {
        buf << "()";
    }
    buf << "\n}\n";
    return buf.str();
}

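// A subgraph is "single" when it consists of exactly one expression whose
// inputs and outputs coincide with the subgraph's own; as_single then exposes
// that expression's op directly.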
bool Subgraph::is_single() const {
    if (exprs.size() != 1) {
        return false;
    }
    auto& expr = exprs.at(0);
    return expr.inputs == inputs && expr.outputs == outputs;
}

std::shared_ptr<OpDef> Subgraph::as_single() const {
    if (is_single()) {
        return exprs.at(0).op;
    } else {
        return nullptr;
    }
}

bool Subgraph::operator==(const Subgraph& rhs) const {
    mgb_assert(false, "Not Implemented");
}

}  // namespace imperative
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}