diff --git a/src/opr/impl/custom_opnode.cpp b/src/opr/impl/custom_opnode.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3b7931d2f0db70db6daee18ca949199065a73bb9 --- /dev/null +++ b/src/opr/impl/custom_opnode.cpp @@ -0,0 +1,329 @@ +/** + * \file src/opr/impl/custom_opnode.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/opr/custom_opnode.h" + +namespace mgb { +namespace opr { + +MGB_DYN_TYPE_OBJ_FINAL_IMPL(CustomOpNode); + +void CustomOpNode::infer_output_comp_node(void) { + SmallVector input_comp_nodes(input_num()); + for (int i=0; icomp_node(); + } + + SmallVector output_comp_nodes = custom::to_builtin( + m_op->infer_output_device( + custom::to_custom(input_comp_nodes), m_param + ) + ); + + for (int i=0; icomp_node(output_comp_nodes[i]); + } + + m_comp_node = output_comp_nodes[0]; +} + +void CustomOpNode::infer_output_dtype(void) { + SmallVector input_dtypes(input_num()); + for (int i=0; idtype(); + } + + SmallVector output_dtypes = custom::to_builtin( + m_op->infer_output_dtype( + custom::to_custom(input_dtypes), m_param + ) + ); + + for (int i=0; idtype(output_dtypes[i]); + } +} + +void CustomOpNode::infer_output_format(void) { + SmallVector input_formats(input_num()); + for (int i=0; iformat(); + } + + SmallVector output_formats = custom::to_builtin( + m_op->infer_output_format( + custom::to_custom(input_formats), m_param + ) + ); + + for (int i=0; iformat(output_formats[i]); + } +} + +void CustomOpNode::infer_output_shape(void) { + SmallVector input_shapes(input_num()); + for (int i=0; ishape(); + } + + SmallVector output_shapes = custom::to_builtin( + m_op->infer_output_shape( + 
custom::to_custom(input_shapes), m_param + ) + ); + + for (int i=0; ishape(output_shapes[i]); + } +} + +void CustomOpNode::infer_output_shape(const TensorShapeArray &input_shapes, + TensorShapeArray &output_shapes) { + output_shapes = custom::to_builtin( + m_op->infer_output_shape( + custom::to_custom(input_shapes), m_param + ) + ); +} + +// called by computing_graph for each output varnode +bool CustomOpNode::infer_desc(size_t out_idx, TensorShape &output_shape, + const StaticInferInpVal &input_vals) { + TensorShapeArray input_shapes(input_vals.val.size()); + TensorShapeArray output_shapes(output_num()); + + for (size_t i = 0; i < input_shapes.size(); ++ i) { + input_shapes[i] = input_vals.val[i].shape(); + } + + infer_output_shape(input_shapes, output_shapes); + output_shape = output_shapes.at(out_idx); + return true; +} + +void CustomOpNode::init_output_dtype() { + infer_output_dtype(); +} + +void CustomOpNode::init_output_format() { + infer_output_format(); +} + +void CustomOpNode::init_output_comp_node() { + infer_output_comp_node(); +} + +void CustomOpNode::do_execute(ExecEnv &env) { + auto runner = [this]() { + this->owner_graph()->event().signal_inplace( + this, m_comp_node + ); + m_comp_node.activate(); + + SmallVector inputs, outputs; + for(size_t i=0; idev_tensor()); + for(size_t i=0; idev_tensor()); + + std::vector custom_inputs = custom::to_custom(inputs); + std::vector custom_outputs = custom::to_custom(outputs); + m_op->compute(custom_inputs, m_param, custom_outputs); + CompNode::sync_all(); // TODO: check whether syncing all comp nodes here is reasonable + + this->owner_graph()->event().signal_inplace( + this, m_comp_node + ); + }; + env.dispatch_on_comp_node(m_comp_node, runner); +} + +void CustomOpNode::init_output_static_infer_desc() { + using namespace std::placeholders; + using namespace cg::static_infer; + + m_out_shape.resize(output_num()); + auto &&mgr = owner_graph()->static_infer_manager(); + + DepVal dep; + if (true) { // TODO: design an interface that lets the user choose the dependency type + for 
(auto input_var: input()) + dep.push_back({input_var, DepType::SHAPE}); + } + else { + for (auto input_var: input()) + dep.push_back({input_var, DepType::VALUE}); + } + + for (size_t i = 0; i < output_num(); ++ i) { + mgr.register_shape_infer(output(i), { + dep.empty() ? SourceType::CONSTANT : SourceType::DEP, dep, + std::bind(&CustomOpNode::infer_desc, this, i, _1, _2) + }); + } +} + +void CustomOpNode::init_output_mem_plan(bool dynamic) { + for (auto output_var: output()) { + if (cg::is_static_var_storage(output_var) == !dynamic + && !output_var->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC)) + output_var->init_mem_plan(); + } +} + +void CustomOpNode::init_rt_force_dynamic_mem_alloc_imply_chain() { + +} + +void CustomOpNode::add_input_layout_constraint() { + for (auto &&input_var: input()) { + input_var->add_layout_constraint_contiguous(); + } +} + +void CustomOpNode::mem_plan_fwd_in2out_readonly() { + +} + +void CustomOpNode::mem_plan_fwd_in2out_writable() { + +} + +cg::OperatorNodeBase::OprEventCallback CustomOpNode::get_opr_event_callback() { + return {}; +} + +void CustomOpNode::on_output_comp_node_stream_changed() { + for (auto output_var: output()) { + if (output_var->comp_node() != m_comp_node) { + mgb_assert(output_var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)); + output_var->comp_node(m_comp_node); + } + } +} + +cg::OperatorNodeBase::NodeProp* CustomOpNode::do_make_node_prop() const { + // auto ret = &const_cast(node_prop()); + // for (auto &&inp_var: input()) + // ret->add_dep_type(inp_var, NodeProp::DepType::DEV_VALUE); + // ret->add_flag(NodeProp::Flag::SINGLE_COMP_NODE); + // return ret; + return OperatorNodeBase::do_make_node_prop(); +} + +bool CustomOpNode::update_priority() const { + if (output_num() == 1 + && output()[0]->contain_flag(VarNode::Flag::PERSISTENT_DEVICE_VALUE)) { + node_prop().attribute().priority + = std::numeric_limits::min(); + return true; + } + return false; +} + +CustomOpNode::CustomOpNode(const std::shared_ptr &op, + 
VarNodeArray inputs, + const custom::Param &param, + const OperatorNodeConfig &config): + OperatorNodeBase(inputs[0]->owner_graph(), config, op->op_type(), inputs), m_op(op), m_param(param) { + mgb_assert(input_num() == inputs.size(), "wrong input tensors list length"); + for (int i=0; i < input_num(); ++i) + add_input({inputs[i]}); + + for (int i=0; i::value) { + using step = unsigned long; + size_t STEP_SIZE = sizeof(step); + std::string hash_str = std::to_string(op->runtime_id()); + for (auto &&val: param.raw()) { + hash_str += val.first; + hash_str += val.second.str(); + } + if (hash_str.size() % STEP_SIZE != 0) + hash_str += std::string(STEP_SIZE - (hash_str.size() % STEP_SIZE), ' '); + for (size_t pos=0; pos >(reinterpret_cast(hash_str.c_str()+pos)); + } +} + +VarNodeArray CustomOpNode::make(const std::shared_ptr &op, + VarNodeArray inputs, + const custom::Param &param, + const OperatorNodeConfig &config) { + auto &&outputs = inputs[0]->owner_graph()->insert_opr( + std::make_unique(op, inputs, param, config))->output(); + return outputs; +} + +SymbolVarArray CustomOpNode::make(const std::shared_ptr &op, + SymbolVarArray inputs, + const custom::Param &param, + const OperatorNodeConfig &config) { + VarNodeArray input_vars(inputs.size()); + for (size_t i=0; iowner_graph()->insert_opr( + std::make_unique(op, input_vars, param, config))->output(); + SymbolVarArray ret(outputs.size()); + for (size_t i=0; iruntime_id(); +} + +uint32_t CustomOpNode::param_tag(void) const { + return m_op->param_info().tag(); +} + +custom::Param& CustomOpNode::param(void) { + return m_param; +} + +custom::Param CustomOpNode::param(void) const { + return m_param; +} + +// a series of functions with the same names as CustomOpImpl +std::string CustomOpNode::op_type(void) const { + return m_op->op_type(); +} + +std::string CustomOpNode::op_desc(void) const { + return m_op->op_desc(); +} + +int CustomOpNode::input_num(void) const { + return m_op->input_num(); +} + +int CustomOpNode::output_num(void) 
const { + return m_op->output_num(); +} + +custom::ArgInfo CustomOpNode::input_info(size_t idx) const { + return m_op->input_info(idx); +} + +custom::ArgInfo CustomOpNode::output_info(size_t idx) const { + return m_op->output_info(idx); +} + +} +} diff --git a/src/opr/impl/custom_opnode.sereg.h b/src/opr/impl/custom_opnode.sereg.h new file mode 100644 index 0000000000000000000000000000000000000000..b4f429ce1dbd2708a6431de696efe85d376796f7 --- /dev/null +++ b/src/opr/impl/custom_opnode.sereg.h @@ -0,0 +1,70 @@ +/** + * \file src/opr/impl/custom_opnode.sereg.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + */ + +#include "megbrain/opr/custom_opnode.h" +#include "megbrain/serialization/sereg.h" + +namespace mgb { +namespace serialization { + +void custom_dumper(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { + auto &&custom_op = opr.cast_final_safe(); + + std::string op_type = custom_op.op_type(); + ctx.dump_buf_with_len(op_type.c_str(), op_type.size()); + + uint32_t tag = custom_op.param_tag(); + ctx.dump_buf_with_len(&tag, sizeof(tag)); + + std::string bytes = custom_op.param().to_bytes(); + ctx.dump_buf_with_len(bytes.c_str(), bytes.size()); +} + +mgb::cg::OperatorNodeBase *custom_loader(OprLoadContext& ctx, + const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + std::string op_type = ctx.load_buf_with_len(); + auto *op_manager = custom::CustomOpManager::inst(); + auto op = op_manager->find(op_type); + + std::string tag_str = ctx.load_buf_with_len(); + uint32_t tag = *reinterpret_cast(tag_str.c_str()); + mgb_assert( + tag == op->param_info().tag(), + "Wrong Param TAG of Op %s, should be %u, but load %u\n", + 
op_type.c_str(), op->param_info().tag(), tag + ); + + custom::Param param(op->param_info()); + std::string bytes = ctx.load_buf_with_len(); + param.from_bytes(bytes); + return opr::CustomOpNode::make(op, inputs, param, config)[0]->owner_opr(); +} + +} +} + +#define CUSTOM_OP_SEREG_REG(cls) \ + namespace { \ + struct _OprReg##cls { \ + static void entry() { \ + MGB_SEREG_OPR_INTL_CALL_ADD( \ + cls, \ + ::mgb::serialization::custom_dumper, \ + ::mgb::serialization::custom_loader); \ + } \ + }; \ + } \ + MGB_SEREG_OPR_INTL_CALL_ENTRY(cls, _OprReg##cls) + +using namespace mgb; +using CustomOpNode = opr::CustomOpNode; +CUSTOM_OP_SEREG_REG(CustomOpNode); diff --git a/src/opr/include/megbrain/opr/custom_opnode.h b/src/opr/include/megbrain/opr/custom_opnode.h new file mode 100644 index 0000000000000000000000000000000000000000..2480d80161caca03242c06319c4ca1d0da4db138 --- /dev/null +++ b/src/opr/include/megbrain/opr/custom_opnode.h @@ -0,0 +1,103 @@ +/** + * \file src/opr/include/megbrain/opr/custom_opnode.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ */ + +#pragma once + +#include "megbrain/custom/custom.h" +#include "megbrain/custom/manager.h" +#include "megbrain/custom/data_adaptor.h" +#include "megbrain/graph/operator_node.h" +#include "megbrain/graph/symbol_var.h" +#include "megbrain/graph/helper.h" +#include "megbrain/graph/event.h" +#include "megbrain/serialization/sereg.h" + +namespace mgb { +namespace opr { + +using VarNode = cg::VarNode; +using VarNodeArray = cg::VarNodeArray; +using SymbolVar = cg::SymbolVar; +using SymbolVarArray = cg::SymbolVarArray; +using StaticInferInpVal = cg::StaticInferInpVal; +using OperatorNodeConfig = cg::OperatorNodeConfig; + +MGB_DEFINE_OPR_CLASS(CustomOpNode, cg::OperatorNodeBase) // { + const std::shared_ptr m_op; + custom::Param m_param; + CompNode m_comp_node; + TensorShapeArray m_out_shape; + + void infer_output_comp_node(void); + void infer_output_dtype(void); + void infer_output_format(void); + void infer_output_shape(void); + void infer_output_shape(const TensorShapeArray &input_shapes, TensorShapeArray &output_shapes); + + // called by computing_graph for each output varnode + bool infer_desc(size_t out_idx, TensorShape &output_shape, const StaticInferInpVal &input_vals); + + void init_output_dtype() override final; + void init_output_format() override final; + void init_output_comp_node() override final; + void do_execute(ExecEnv &env) override final; + void init_output_static_infer_desc() override final; + void init_output_mem_plan(bool dynamic) override final; + + // [TODO] if some dynamic mem alloc flag in m_opimpl, ignore it for now + void init_rt_force_dynamic_mem_alloc_imply_chain() override final; + + // [TODO] only contiguous input is supported + void add_input_layout_constraint() override final; + + // [TODO] ignore it for now + void mem_plan_fwd_in2out_readonly() override final; + + // [TODO] ignore it for now + void mem_plan_fwd_in2out_writable() override final; + + // [TODO] return default ctor obj + OprEventCallback get_opr_event_callback() 
override final; + + // [TODO] + void on_output_comp_node_stream_changed() override final; + + // [TODO] + NodeProp* do_make_node_prop() const override final; + + // [TODO] default implementation + bool update_priority() const override final; + +public: + CustomOpNode(const std::shared_ptr &op, + VarNodeArray inputs, const custom::Param &param, + const OperatorNodeConfig &config); + static VarNodeArray make(const std::shared_ptr &op, + VarNodeArray inputs, const custom::Param &param, + const OperatorNodeConfig &config); + static SymbolVarArray make(const std::shared_ptr &op, + SymbolVarArray inputs, const custom::Param &param, + const OperatorNodeConfig &config); + + custom::RunTimeId runtime_id(void) const; + uint32_t param_tag(void) const; + custom::Param& param(void); + custom::Param param(void) const; + std::string op_type(void) const; + std::string op_desc(void) const; + int input_num(void) const; + int output_num(void) const; + custom::ArgInfo input_info(size_t idx) const; + custom::ArgInfo output_info(size_t idx) const; +}; + +} // namespace opr +} diff --git a/src/serialization/impl/sereg_caller.cpp b/src/serialization/impl/sereg_caller.cpp index dc61c1babbd671e22c78ea0776c80d7c48d794b7..08a5c96ca68dcd2b314ca770ee8387411a9e2809 100644 --- a/src/serialization/impl/sereg_caller.cpp +++ b/src/serialization/impl/sereg_caller.cpp @@ -29,6 +29,7 @@ namespace mgb{void call_sereg(){}} #include "../../opr/impl/tensor_gen.sereg.h" #include "../../opr/impl/tensor_manip.sereg.h" #include "../../opr/impl/utility.sereg.h" +#include "../../opr/impl/custom_opnode.sereg.h" #if MGB_ENABLE_TENSOR_RT #include "../../tensorrt/impl/tensorrt_opr.sereg.h" #endif