#include "megbrain/imperative/ops/opr_attr.h"

#include <cstring>
#include <exception>
#include <utility>

#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/rdnn/profiler.h"
#include "megbrain/serialization/opr_load_dump.h"

#include "../op_trait.h"
#include "megbrain/imperative/proxy_graph_detail.h"

namespace mgb {
namespace imperative {

namespace {
M
Megvii Engine Team 已提交
15
class OprParamsLoadContext final : public serialization::OprLoadContextRawPOD {
16 17
    const OprAttr::Param& m_param;
    size_t m_pos = 0;
M
Megvii Engine Team 已提交
18
    ComputingGraph* m_graph;
19

M
Megvii Engine Team 已提交
20
    void read_raw(void* dest, size_t size) override final {
21 22 23 24 25
        mgb_assert(m_pos + size <= m_param.size(), "too many bytes requested");
        memcpy(dest, m_param.data() + m_pos, size);
        m_pos += size;
    }

M
Megvii Engine Team 已提交
26
    std::shared_ptr<HostTensorND> load_tensor() override { mgb_assert(0); }
27

M
Megvii Engine Team 已提交
28
    std::shared_ptr<DeviceTensorND> load_tensor_shared() override { mgb_assert(0); }
29

M
Megvii Engine Team 已提交
30
    const serialization::GraphLoadConfig& config() const override { mgb_assert(0); }
31

M
Megvii Engine Team 已提交
32 33 34 35 36
public:
    OprParamsLoadContext(const OprAttr::Param& param, ComputingGraph* graph)
            : serialization::OprLoadContextRawPOD(false),
              m_param(param),
              m_graph(graph) {}
37

M
Megvii Engine Team 已提交
38 39 40
    ~OprParamsLoadContext() {
        mgb_assert(m_pos == m_param.size(), "param not fully consumed");
    }
41

M
Megvii Engine Team 已提交
42
    ComputingGraph& graph() override { return *m_graph; }
43 44
};

M
Megvii Engine Team 已提交
45
class OprParamsDumpContext final : public serialization::OprDumpContextRawPOD {
46 47 48
public:
    OprAttr::Param m_param;
    OprParamsDumpContext() : serialization::OprDumpContextRawPOD(false) {}
M
Megvii Engine Team 已提交
49
    void write_raw(const void* data, size_t size) {
50 51 52 53
        const char* src = static_cast<const char*>(data);
        m_param.insert(m_param.end(), src, src + size);
    }
    void dump_tensor(
M
Megvii Engine Team 已提交
54
            const std::string& name, const HostTensorND& tensor,
55 56 57
            TensorWriteMethod method) {
        mgb_assert(0);
    }
M
Megvii Engine Team 已提交
58
    const serialization::GraphDumpConfig& config() const { mgb_assert(0); }
59 60
};

61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
#define cb(FASTRUN_OPR)                                                           \
    megdnn::param::ExecutionPolicy get_strategy_##FASTRUN_OPR(                    \
            cg::OperatorNodeBase* opr) {                                          \
        auto policy =                                                             \
                opr->cast_final<opr::FASTRUN_OPR>().execution_policy_transient(); \
        return policy;                                                            \
    }                                                                             \
    void set_strategy_##FASTRUN_OPR(                                              \
            cg::OperatorNodeBase* opr, megdnn::param::ExecutionPolicy policy) {   \
        auto&& p = opr->cast_final<opr::FASTRUN_OPR>();                           \
        p.set_execution_policy(policy);                                           \
    }

DNN_FOREACH_FASTRUN_OPR(cb)
#undef cb

// Type-erased getter / setter for an operator node's execution policy.
using get_func = thin_function<megdnn::param::ExecutionPolicy(cg::OperatorNodeBase*)>;
using set_func =
        thin_function<void(cg::OperatorNodeBase*, megdnn::param::ExecutionPolicy)>;

//! Read-only table mapping each fast-run operator's typeinfo to its
//! (get_strategy_*, set_strategy_*) accessor pair. Built on first call.
static const mgb::thin_hash_table::ThinHashMap<
        mgb::Typeinfo*, std::pair<get_func, set_func>>&
get_type2policy() {
    using Map = mgb::thin_hash_table::ThinHashMap<
            mgb::Typeinfo*, std::pair<get_func, set_func>>;
    // Function-local static initialization is thread-safe since C++11, so no
    // explicit std::call_once is needed and the table can be const.
    static const Map sl_type2policy = [] {
        Map ret;
#define cb(FASTRUN_OPR)                 \
    ret[opr::FASTRUN_OPR::typeinfo()] = \
            std::make_pair(get_strategy_##FASTRUN_OPR, set_strategy_##FASTRUN_OPR);
        DNN_FOREACH_FASTRUN_OPR(cb)
        // undefine so the helper macro does not leak into the rest of the file
#undef cb
        return ret;
    }();
    return sl_type2policy;
}

M
Megvii Engine Team 已提交
97
VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
98
    auto&& attr = def.cast_final_safe<OprAttr>();
99 100
    auto config = attr.config;
    config.name(attr.make_name());
101 102 103 104
    mgb_assert(!inputs.empty());
    auto registry = serialization::OprRegistry::find_by_name(attr.type);
    mgb_assert(registry, "operator %s not found", attr.type.c_str());
    OprParamsLoadContext ctx{attr.param, inputs[0]->owner_graph()};
105 106 107 108 109 110
    auto opr_with_accessor = registry->loader(ctx, inputs, config);
    auto&& opr = opr_with_accessor.opr();
    if (get_type2policy().find(opr->dyn_typeinfo()) != get_type2policy().end()) {
        get_type2policy().at(opr->dyn_typeinfo()).second(opr, attr.policy);
    }
    return opr_with_accessor.usable_output();
111 112 113 114 115 116
}

//! Builds an OprAttr from an existing operator node. The operator's registered
//! dumper serializes its params. For fast-run operators, the current execution
//! policy is captured as well.
std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* opr) {
    OprParamsDumpContext ctx;
    auto registry = serialization::OprRegistry::find_by_type(opr->dyn_typeinfo());
    mgb_assert(registry, "operator %s not found", opr->dyn_typeinfo()->name);
    mgb_assert(
            registry->dumper, "operator %s cannot be serialized",
            opr->dyn_typeinfo()->name);
    registry->dumper(ctx, *opr);
    // non-fast-run operators keep the default-constructed policy
    megdnn::param::ExecutionPolicy policy{};
    auto&& type2policy = get_type2policy();
    auto iter = type2policy.find(opr->dyn_typeinfo());
    if (iter != type2policy.end()) {
        policy = iter->second.first(opr);
    }
    return OprAttr::make(registry->name, std::move(ctx.m_param), policy, opr->config());
}

128 129 130 131
std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
    return {};
}

132
std::string make_name(const OpDef& def) {
133 134
    auto&& attr = def.cast_final_safe<OprAttr>();
    return attr.type;
135 136
}

137
OP_TRAIT_REG(OprAttr, OprAttr)
M
Megvii Engine Team 已提交
138 139 140 141 142
        .make_from_op_node(make_from_op_node)
        .apply_on_var_node(apply_on_var_node)
        .props(props)
        .make_name(make_name)
        .fallback();
143

M
Megvii Engine Team 已提交
144
}  // anonymous namespace
145 146 147

bool OprAttr::is_same_st(const Hashable& rhs_) const {
    auto&& rhs = static_cast<const OprAttr&>(rhs_);
M
Megvii Engine Team 已提交
148
    return type == rhs.type && param == rhs.param &&
149 150
           policy.strategy == rhs.policy.strategy &&
           policy.workspace_limit == rhs.policy.workspace_limit &&
M
Megvii Engine Team 已提交
151 152
           config.comp_node() == rhs.config.comp_node() &&
           config.output_dtype() == rhs.config.output_dtype();
153 154 155 156 157
}

size_t OprAttr::hash() const {
    return hash_pair_combine(
            hash_pair_combine(
158 159 160 161 162 163
                    hash_pair_combine(
                            mgb::hash(type),
                            mgb::hash(static_cast<std::vector<char>>(param))),
                    hash_pair_combine(
                            static_cast<size_t>(policy.strategy),
                            policy.workspace_limit)),
164 165 166 167 168
            config.hash());
}

// Out-of-line part of OprAttr's dynamic type info (presumably pairs with a
// MGB_DYN_TYPE_OBJ_FINAL_DECL in the header — confirm).
MGB_DYN_TYPE_OBJ_FINAL_IMPL(OprAttr);

}  // namespace imperative
}  // namespace mgb
171 172

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}