custom_opnode.cpp 10.1 KB
Newer Older
1 2
#include "megbrain/opr/custom_opnode.h"

#include <cstring>

3 4
#if MGB_CUSTOM_OP

5 6 7 8 9 10 11
namespace mgb {
namespace opr {

// Out-of-line definition of CustomOpNode's dynamic type object.
// NOTE(review): the macro comes from MegBrain's type-info headers (not shown
// here); presumably it provides the runtime type info used for dyn casts.
MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomOpNode);

void CustomOpNode::infer_output_comp_node(void) {
    SmallVector<CompNode> input_comp_nodes(input_num());
M
Megvii Engine Team 已提交
12
    for (size_t i = 0; i < input_num(); ++i) {
13 14 15
        input_comp_nodes[i] = input(i)->comp_node();
    }

M
Megvii Engine Team 已提交
16 17 18 19 20 21 22 23 24
    SmallVector<CompNode> output_comp_nodes =
            custom::to_builtin<CompNode, custom::Device>(m_op->infer_output_device(
                    custom::to_custom<CompNode, custom::Device>(input_comp_nodes),
                    m_param));

    for (size_t i = 0; i < output_num(); ++i) {
        mgb_assert(
                output_comp_nodes[i] == output_comp_nodes[0],
                "only single comp node operator is supported");
25 26 27 28 29 30 31 32
        output(i)->comp_node(output_comp_nodes[i]);
    }

    m_comp_node = output_comp_nodes[0];
}

void CustomOpNode::infer_output_dtype(void) {
    SmallVector<DType> input_dtypes(input_num());
M
Megvii Engine Team 已提交
33
    for (size_t i = 0; i < input_num(); ++i) {
34 35 36
        input_dtypes[i] = input(i)->dtype();
    }

M
Megvii Engine Team 已提交
37 38 39 40
    SmallVector<DType> output_dtypes =
            custom::to_builtin<megdnn::DType, custom::DType>(m_op->infer_output_dtype(
                    custom::to_custom<megdnn::DType, custom::DType>(input_dtypes),
                    m_param));
41

M
Megvii Engine Team 已提交
42
    for (size_t i = 0; i < output_num(); ++i) {
43 44 45 46 47 48
        output(i)->dtype(output_dtypes[i]);
    }
}

void CustomOpNode::infer_output_format(void) {
    SmallVector<TensorFormat> input_formats(input_num());
M
Megvii Engine Team 已提交
49
    for (size_t i = 0; i < input_num(); ++i) {
50 51 52
        input_formats[i] = input(i)->format();
    }

M
Megvii Engine Team 已提交
53 54 55 56
    SmallVector<TensorFormat> output_formats =
            custom::to_builtin<TensorFormat, custom::Format>(m_op->infer_output_format(
                    custom::to_custom<TensorFormat, custom::Format>(input_formats),
                    m_param));
57

M
Megvii Engine Team 已提交
58
    for (size_t i = 0; i < output_num(); ++i) {
59 60 61 62 63 64
        output(i)->format(output_formats[i]);
    }
}

void CustomOpNode::infer_output_shape(void) {
    SmallVector<TensorShape> input_shapes(input_num());
M
Megvii Engine Team 已提交
65
    for (size_t i = 0; i < input_num(); ++i) {
66 67 68
        input_shapes[i] = input(i)->shape();
    }

M
Megvii Engine Team 已提交
69 70 71 72
    SmallVector<TensorShape> output_shapes =
            custom::to_builtin<TensorShape, custom::Shape>(m_op->infer_output_shape(
                    custom::to_custom<TensorShape, custom::Shape>(input_shapes),
                    m_param));
73

M
Megvii Engine Team 已提交
74
    for (size_t i = 0; i < output_num(); ++i) {
75 76 77 78
        output(i)->shape(output_shapes[i]);
    }
}

M
Megvii Engine Team 已提交
79 80 81 82 83 84
void CustomOpNode::infer_output_shape(
        const TensorShapeArray& input_shapes, TensorShapeArray& output_shapes) {
    output_shapes =
            custom::to_builtin<TensorShape, custom::Shape>(m_op->infer_output_shape(
                    custom::to_custom<TensorShape, custom::Shape>(input_shapes),
                    m_param));
85 86 87
}

// called by computing_graph for each output varnode
M
Megvii Engine Team 已提交
88 89 90
bool CustomOpNode::infer_desc(
        size_t out_idx, TensorShape& output_shape,
        const StaticInferInpVal& input_vals) {
91 92 93
    TensorShapeArray input_shapes(input_vals.val.size());
    TensorShapeArray output_shapes(output_num());

M
Megvii Engine Team 已提交
94
    for (size_t i = 0; i < input_shapes.size(); ++i) {
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
        input_shapes[i] = input_vals.val[i].shape();
    }

    infer_output_shape(input_shapes, output_shapes);
    output_shape = output_shapes.at(out_idx);
    return true;
}

void CustomOpNode::init_output_dtype() {
    infer_output_dtype();
}

void CustomOpNode::init_output_format() {
    infer_output_format();
}

void CustomOpNode::init_output_comp_node() {
    infer_output_comp_node();
}

M
Megvii Engine Team 已提交
115
// Schedule the custom kernel on m_comp_node. The task emits the
// BeforeKernel/AfterKernel graph events around the user's compute() call.
void CustomOpNode::do_execute(ExecEnv& env) {
    auto kernel_task = [this]() {
        this->owner_graph()->event().signal_inplace<cg::event::BeforeKernel>(
                this, m_comp_node);
        m_comp_node.activate();

        // Device tensors of every input, then of every output.
        SmallVector<DeviceTensorND> in_tensors;
        in_tensors.reserve(input_num());
        for (size_t idx = 0; idx < input_num(); ++idx) {
            in_tensors.push_back(input(idx)->dev_tensor());
        }
        SmallVector<DeviceTensorND> out_tensors;
        out_tensors.reserve(output_num());
        for (size_t idx = 0; idx < output_num(); ++idx) {
            out_tensors.push_back(output(idx)->dev_tensor());
        }

        // Convert to the custom-op tensor type and run the user kernel.
        std::vector<custom::Tensor> custom_ins =
                custom::to_custom<DeviceTensorND, custom::Tensor>(in_tensors);
        std::vector<custom::Tensor> custom_outs =
                custom::to_custom<DeviceTensorND, custom::Tensor>(out_tensors);
        m_op->compute(custom_ins, m_param, custom_outs);

        // [TODO] sync should be modified
        // (currently waits for every comp node, not just m_comp_node)
        CompNode::sync_all();

        this->owner_graph()->event().signal_inplace<cg::event::AfterKernel>(
                this, m_comp_node);
    };
    env.dispatch_on_comp_node(m_comp_node, kernel_task);
}

void CustomOpNode::init_output_static_infer_desc() {
    using namespace std::placeholders;
    using namespace cg::static_infer;

    m_out_shape.resize(output_num());
M
Megvii Engine Team 已提交
146
    auto&& mgr = owner_graph()->static_infer_manager();
147 148

    DepVal dep;
149
    // [TODO] need design a interface to allow user to decide it
M
Megvii Engine Team 已提交
150 151
    if (true) {
        for (auto input_var : input())
152
            dep.push_back({input_var, DepType::SHAPE});
M
Megvii Engine Team 已提交
153 154
    } else {
        for (auto input_var : input())
155 156 157
            dep.push_back({input_var, DepType::VALUE});
    }

M
Megvii Engine Team 已提交
158 159 160 161
    for (size_t i = 0; i < output_num(); ++i) {
        mgr.register_shape_infer(
                output(i), {dep.empty() ? SourceType::CONSTANT : SourceType::DEP, dep,
                            std::bind(&CustomOpNode::infer_desc, this, i, _1, _2)});
162 163 164 165
    }
}

void CustomOpNode::init_output_mem_plan(bool dynamic) {
M
Megvii Engine Team 已提交
166 167 168
    for (auto output_var : output()) {
        if (cg::is_static_var_storage(output_var) == !dynamic &&
            !output_var->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC))
169 170 171 172
            output_var->init_mem_plan();
    }
}

M
Megvii Engine Team 已提交
173
void CustomOpNode::init_rt_force_dynamic_mem_alloc_imply_chain() {
    // Intentionally empty: no force-dynamic-alloc implication chain is set up.
}
174

M
Megvii Engine Team 已提交
175 176
// Require every input to be contiguous before it reaches the custom kernel.
// NOTE(review): presumably because user kernels assume a dense layout.
void CustomOpNode::add_input_layout_constraint() {
    const auto& input_vars = input();
    for (size_t idx = 0; idx < input_vars.size(); ++idx) {
        input_vars[idx]->add_layout_constraint_contiguous();
    }
}

M
Megvii Engine Team 已提交
181
void CustomOpNode::mem_plan_fwd_in2out_readonly() {
    // Intentionally empty: no readonly input-to-output memory forwarding.
}
182

M
Megvii Engine Team 已提交
183
void CustomOpNode::mem_plan_fwd_in2out_writable() {
    // Intentionally empty: no writable input-to-output memory forwarding.
}
184

M
Megvii Engine Team 已提交
185
// No operator event hooks are installed; hand back a default-constructed
// callback object.
cg::OperatorNodeBase::OprEventCallback CustomOpNode::get_opr_event_callback() {
    return OprEventCallback{};
}

M
Megvii Engine Team 已提交
189 190
void CustomOpNode::on_output_comp_node_stream_changed() {
    for (auto output_var : output()) {
191 192 193 194 195 196 197
        if (output_var->comp_node() != m_comp_node) {
            mgb_assert(output_var->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
            output_var->comp_node(m_comp_node);
        }
    }
}

M
Megvii Engine Team 已提交
198
// No custom-op specific node properties: defer entirely to the base class.
cg::OperatorNodeBase::NodeProp* CustomOpNode::do_make_node_prop() const {
    NodeProp* prop = OperatorNodeBase::do_make_node_prop();
    return prop;
}

bool CustomOpNode::update_priority() const {
M
Megvii Engine Team 已提交
203 204 205 206
    if (output_num() == 1 &&
        output()[0]->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE)) {
        node_prop().attribute().priority =
                std::numeric_limits<decltype(NodeProp::Attribute::priority)>::min();
207 208 209 210 211
        return true;
    }
    return false;
}

M
Megvii Engine Team 已提交
212 213 214 215 216 217
// Build a graph node wrapping the custom op `op` with the given param, wiring
// `inputs` as input vars and creating one output var per declared output.
//
// NOTE(review): inputs[0] is dereferenced in the initializer list before any
// check can run, so callers must pass a non-empty input list.
CustomOpNode::CustomOpNode(
        const std::shared_ptr<const custom::CustomOp>& op, VarNodeArray inputs,
        const custom::Param& param, const OperatorNodeConfig& config)
        : OperatorNodeBase(inputs[0]->owner_graph(), config, op->op_type(), inputs),
          m_op(op),
          m_param(param) {
    mgb_assert(input_num() == inputs.size(), "wrong input tensors list length");
    for (size_t i = 0; i < input_num(); ++i)
        add_input({inputs[i]});

    for (size_t i = 0; i < output_num(); ++i)
        add_output(output_info(i).name());

    // Make the op identity (runtime id plus every param name/value) part of the
    // equivalence key, so graph-level dedup only merges identical ops.
    if (!std::is_empty<custom::Param>::value) {
        using step = unsigned long;
        constexpr size_t STEP_SIZE = sizeof(step);

        std::string hash_str;
        // Each field is written as "<len>:<bytes>". Without delimiters,
        // distinct field sequences could concatenate to the same bytes
        // (e.g. {"ab", "c"} vs {"a", "bc"}). Values differing only in
        // trailing spaces would also collide after the padding below.
        auto append_field = [&hash_str](const std::string& field) {
            hash_str += std::to_string(field.size());
            hash_str += ':';
            hash_str += field;
        };
        append_field(std::to_string(op->runtime_id()));
        // NOTE(review): if raw() is an unordered container, confirm its
        // iteration order is deterministic for equal params.
        for (auto&& val : param.raw()) {
            append_field(val.first);
            append_field(val.second.str());
        }
        if (hash_str.size() % STEP_SIZE != 0)
            hash_str += std::string(STEP_SIZE - (hash_str.size() % STEP_SIZE), ' ');

        // Copy every chunk by value into its component. The previous
        // PODHash<step> components referenced memory inside `hash_str`, a
        // local that is destroyed when this constructor returns, while the
        // components are hashed and compared afterwards. memcpy also avoids
        // the unaligned, aliasing-violating reinterpret_cast of char data.
        for (size_t pos = 0; pos < hash_str.size(); pos += STEP_SIZE) {
            step chunk;
            std::memcpy(&chunk, hash_str.data() + pos, STEP_SIZE);
            add_equivalence_component<ScalarHash<step>>(chunk);
        }
    }
}

M
Megvii Engine Team 已提交
241 242 243 244 245 246 247 248
// Create a CustomOpNode for `op` and insert it into the graph that owns the
// inputs.
// Returns the output vars of the inserted (or deduplicated) operator.
VarNodeArray CustomOpNode::make(
        const std::shared_ptr<const custom::CustomOp>& op, VarNodeArray inputs,
        const custom::Param& param, const OperatorNodeConfig& config) {
    // The owner graph comes from the first input; an empty list would be
    // dereferenced out of range.
    mgb_assert(
            !inputs.empty(), "custom op %s requires at least one input var",
            op->op_type().c_str());
    auto* graph = inputs[0]->owner_graph();
    auto* opr = graph->insert_opr(
            std::make_unique<CustomOpNode>(op, std::move(inputs), param, config));
    return opr->output();
}

M
Megvii Engine Team 已提交
252 253 254
// SymbolVar convenience overload: unwraps the inputs, delegates to the
// VarNodeArray overload and wraps the resulting output vars.
SymbolVarArray CustomOpNode::make(
        const std::shared_ptr<const custom::CustomOp>& op, SymbolVarArray inputs,
        const custom::Param& param, const OperatorNodeConfig& config) {
    // The owner graph comes from the first input; an empty list would be
    // dereferenced out of range.
    mgb_assert(
            !inputs.empty(), "custom op %s requires at least one input var",
            op->op_type().c_str());

    VarNodeArray input_vars;
    input_vars.reserve(inputs.size());
    for (auto&& sym : inputs)
        input_vars.push_back(sym.node());

    VarNodeArray output_vars = make(op, std::move(input_vars), param, config);

    SymbolVarArray ret;
    ret.reserve(output_vars.size());
    for (auto* var : output_vars)
        ret.emplace_back(var);
    return ret;
}

// Runtime id of the wrapped custom op.
custom::RunTimeId CustomOpNode::runtime_id() const {
    return this->m_op->runtime_id();
}

uint32_t CustomOpNode::param_tag(void) const {
    return m_op->param_info().tag();
}

// Mutable access to the param stored in this node.
custom::Param& CustomOpNode::param(void) {
    return this->m_param;
}

// Copy of the param stored in this node.
custom::Param CustomOpNode::param(void) const {
    return this->m_param;
}

// a series of functions with the same names as CustomOpImpl
std::string CustomOpNode::op_type(void) const {
    return m_op->op_type();
}

std::string CustomOpNode::op_desc(void) const {
    return m_op->op_desc();
}

296
size_t CustomOpNode::input_num(void) const {
297 298 299
    return m_op->input_num();
}

300
size_t CustomOpNode::output_num(void) const {
301 302 303 304 305 306 307 308 309 310 311
    return m_op->output_num();
}

// Declared argument info of input `idx`, as reported by the custom op.
custom::ArgInfo CustomOpNode::input_info(size_t idx) const {
    return this->m_op->input_info(idx);
}

// Declared argument info of output `idx`, as reported by the custom op.
custom::ArgInfo CustomOpNode::output_info(size_t idx) const {
    return this->m_op->output_info(idx);
}

M
Megvii Engine Team 已提交
312 313
}  // namespace opr
}  // namespace mgb
314 315

#endif