From b4581788472280a04b758aeeadc4ce72ba7c253c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Sun, 26 Sep 2021 19:32:35 +0800 Subject: [PATCH] feat(opr): add mutable tensor opr GitOrigin-RevId: 7f8a3d7b661c18fb407047a78b965630b52e61d9 --- imperative/src/impl/opr_utility.cpp | 63 +++++++++++++++++++ .../include/megbrain/imperative/opr_utility.h | 23 +++++++ 2 files changed, 86 insertions(+) diff --git a/imperative/src/impl/opr_utility.cpp b/imperative/src/impl/opr_utility.cpp index c2550960f..3b61ded3a 100644 --- a/imperative/src/impl/opr_utility.cpp +++ b/imperative/src/impl/opr_utility.cpp @@ -271,6 +271,69 @@ void NopCallback::do_execute(ExecEnv& env) { env.dispatch_on_comp_node(cn, runner); } +MGB_DYN_TYPE_OBJ_FINAL_IMPL(MutableTensor); +MutableTensor::MutableTensor( + cg::ComputingGraph& graph, std::shared_ptr dev_tensor, + std::shared_ptr host_tensor, const OperatorNodeConfig& config) + : Super(&graph, config, {}, {}) { + m_dev_tensor = dev_tensor; + m_host_tensor = host_tensor; + + add_output(None) + ->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) + .add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC) + .dtype(m_dev_tensor->dtype()); + add_equivalence_component>(this); +} + +SymbolVar MutableTensor::make( + cg::ComputingGraph& graph, std::shared_ptr dev_tensor, + std::shared_ptr host_tensor, const OperatorNodeConfig& config) { + return graph + .insert_opr(std::make_unique( + graph, dev_tensor, host_tensor, config)) + ->output(0); +} + +void MutableTensor::init_output_comp_node() { + if (config().has_comp_node_set()) { + mgb_assert( + config().get_single_comp_node() == m_dev_tensor->comp_node(), + "comp_node mismatch"); + } + comp_node(m_dev_tensor->comp_node()); +} + +cg::OperatorNodeBase::NodeProp* MutableTensor::do_make_node_prop() const { + auto ret = Super::do_make_node_prop(); + ret->add_flag(NodeProp::Flag::IMPURE_OUTPUT_MEM_PLAN); + return ret; +} + +void MutableTensor::scn_do_execute() { + output(0)->reset_dev_tensor_from_tensor(*m_dev_tensor); +} + +void MutableTensor::init_output_static_infer_desc() { + using namespace cg::static_infer; + auto& mgr = owner_graph()->static_infer_manager(); + auto infer_shape = [this](TensorShape& dest, const InpVal&) { + dest = m_dev_tensor->shape(); + return true; + }; + mgr.register_shape_infer(output(0), {SourceType::MUTABLE, {}, infer_shape}); + if (m_host_tensor) { + auto infer_value = [this](DeviceTensorND& dest, const InpVal&) { + if (!m_host_tensor->layout().ndim) { + return false; + } + dest = m_host_tensor->proxy_to_default_cpu(); + return true; + }; + mgr.register_value_infer(output(0), {SourceType::MUTABLE, {}, infer_value}); + } +} + } // namespace opr } // namespace mgb diff --git a/imperative/src/include/megbrain/imperative/opr_utility.h b/imperative/src/include/megbrain/imperative/opr_utility.h index 589783538..3edc94504 100644 --- a/imperative/src/include/megbrain/imperative/opr_utility.h +++ b/imperative/src/include/megbrain/imperative/opr_utility.h @@ -16,6 +16,7 @@ #include "megbrain/opr/internal/identical_fwd.h" #include "megbrain/opr/internal/megdnn_opr_wrapper.h" #include "megbrain/opr/internal/param_tag_defs.h" +#include "megbrain/opr/io.h" #include "megbrain/opr/param_defs.h" #include "megbrain/serialization/sereg.h" @@ -106,6 +107,28 @@ protected: private: callback_t m_callback; }; + +MGB_DEFINE_OPR_CLASS(MutableTensor, cg::SingleCNOperatorNodeBase) // { +public: + MutableTensor( + cg::ComputingGraph& graph, std::shared_ptr dev_tensor, + std::shared_ptr host_tensor, + const OperatorNodeConfig& config); + static SymbolVar make( + cg::ComputingGraph& graph, std::shared_ptr dev_tensor, + std::shared_ptr host_tensor = {}, + const OperatorNodeConfig& config = {}); + +protected: + void init_output_comp_node() override; + void init_output_static_infer_desc() override; + cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; + void scn_do_execute() override; + +private: + std::shared_ptr m_dev_tensor; + std::shared_ptr m_host_tensor; +}; } // namespace opr } // namespace mgb -- GitLab