From 78d7d400d1979abdf9776ddaa60442b5cb91382c Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 8 Sep 2020 21:17:52 +0800 Subject: [PATCH] feat(opr): add a constant flavor for SharedDeviceTensor Also, add a CONSTANT value inference tag to outputs of MultipleDeviceTensorHolder. GitOrigin-RevId: 82a805ed5fed68376c0638a902e303f5a651a478 --- src/gopt/impl/inference.cpp | 4 +- src/opr/impl/io.cpp | 54 +++++++++++++++++---- src/opr/impl/io.sereg.h | 11 +++-- src/opr/include/megbrain/opr/io.h | 78 ++++++++++++++++++++++++++----- 4 files changed, 121 insertions(+), 26 deletions(-) diff --git a/src/gopt/impl/inference.cpp b/src/gopt/impl/inference.cpp index 88117d196..a7739d588 100644 --- a/src/gopt/impl/inference.cpp +++ b/src/gopt/impl/inference.cpp @@ -570,10 +570,10 @@ void ParamFusePass::apply(OptState &state) const { *var->owner_graph(), hv, var_namer.name(var)); } else { if (is_default_format) { - new_var = opr::SharedDeviceTensor::make( + new_var = opr::SharedDeviceTensor::make_const( *var->owner_graph(), inferred_val, var_namer.name(var)); } else { - new_var = opr::SharedDeviceTensorWithFormat::make( + new_var = opr::SharedDeviceTensorWithFormat::make_const( *var->owner_graph(), inferred_val, var_namer.name(var)); } } diff --git a/src/opr/impl/io.cpp b/src/opr/impl/io.cpp index 5e34884a3..279f26e2d 100644 --- a/src/opr/impl/io.cpp +++ b/src/opr/impl/io.cpp @@ -281,11 +281,11 @@ void Host2DeviceCopy::record_execute_deps(ExecDependencyArray& deps) { /* ===================== SharedDeviceTensor related ===================== */ intl::SharedDeviceTensorBase::SharedDeviceTensorBase( - ComputingGraph &graph, const std::shared_ptr &dev_data, - const OperatorNodeConfig &config): - Super{&graph, config, "shared", {}}, - m_dev_data{dev_data} -{ + ComputingGraph& graph, const std::shared_ptr& dev_data, + bool const_value, const OperatorNodeConfig& config) + : Super{&graph, config, "shared", {}}, + m_dev_data{dev_data}, + m_const_value(const_value) { if 
(config.has_comp_node_set()) { mgb_assert(config.get_single_comp_node() == dev_data->comp_node()); } @@ -307,26 +307,42 @@ void intl::SharedDeviceTensorBase::init_output_comp_node() { comp_node(m_dev_data->comp_node()); } +bool intl::SharedDeviceTensorBase::fill_in_static_infer(DeviceTensorND* dest) { + if (m_const_value) { + if (dest) { + if (m_static_infer.empty()) { + m_static_infer.comp_node(CompNode::default_cpu()) + .copy_from(*m_dev_data); + } + *dest = m_static_infer; + } + return true; + } + return false; +} + cg::static_infer::SourceType SharedDeviceTensor::static_infer_src_type() const { return cg::static_infer::SourceType::CONSTANT; } SymbolVar SharedDeviceTensor::make(ComputingGraph &graph, const std::shared_ptr &dev_data, + bool const_value, const OperatorNodeConfig &config) { return graph.insert_opr(std::make_unique( - graph, dev_data, config))->output(0); + graph, dev_data, const_value, config))->output(0); } SymbolVar SharedDeviceTensor::make(ComputingGraph &graph, const HostTensorND &value, + bool const_value, const OperatorNodeConfig &config) { auto cn = value.comp_node(); if (config.has_comp_node_set()) cn = config.get_single_comp_node(); auto dev_v = std::make_shared(); dev_v->comp_node(cn).copy_from(value).sync(); - return make(graph, dev_v, config); + return make(graph, dev_v, const_value, config); } MGB_DYN_TYPE_OBJ_FINAL_IMPL(SharedDeviceTensor); @@ -342,7 +358,7 @@ SymbolVar VolatileSharedDeviceTensor::make(ComputingGraph &graph, const std::shared_ptr &dev_data, const OperatorNodeConfig &config) { return graph.insert_opr(std::make_unique( - graph, dev_data, config))->output(0); + graph, dev_data, false, config))->output(0); } MGB_DYN_TYPE_OBJ_FINAL_IMPL(VolatileSharedDeviceTensor); @@ -354,10 +370,10 @@ void SharedDeviceTensorWithFormat::init_output_format() { SymbolVar SharedDeviceTensorWithFormat::make( ComputingGraph& graph, const std::shared_ptr& dev_data, - const OperatorNodeConfig& config) { + bool const_value, const 
OperatorNodeConfig& config) { auto&& opr = graph.insert_opr(std::make_unique( - graph, dev_data, config)) + graph, dev_data, const_value, config)) ->cast_final_safe(); return opr.output(0); } @@ -870,6 +886,24 @@ void intl::MultipleDeviceTensorHolderBase::init_output_static_infer_desc() { }; mgr.register_shape_infer(output(i), {SourceType::CONSTANT, {}, infer_shp}); + + auto infer_val = [this, i](DeviceTensorND& dest, const InpVal&) { + if (m_host_values.empty()) { + m_host_values.resize(m_values.size()); + } + if (m_host_values[i].empty()) { + m_host_values[i] + .comp_node(CompNode::default_cpu()) + .copy_from(*m_values[i]); + } + if (!m_host_values[i].empty()) { + dest = m_host_values[i]; + return true; + } + return false; + }; + mgr.register_value_infer(output(i), + {SourceType::CONSTANT, {}, infer_val}); } } diff --git a/src/opr/impl/io.sereg.h b/src/opr/impl/io.sereg.h index 32a0ce53a..f41898dbf 100644 --- a/src/opr/impl/io.sereg.h +++ b/src/opr/impl/io.sereg.h @@ -79,6 +79,10 @@ namespace serialization { HostTensorND val; val.copy_from(opr.get_dev_tensor()).sync(); ctx.dump_tensor(opr.name(), val, Meth::VALUE_SHARED); + // Note that we don't persist opr.m_const_value, because it does not + // affect correctness, and SharedDeviceTensor will be bundled + // together as MultipleDeviceTensorHolder in optimize_for_inference + // before being dumped. } static cg::OperatorNodeBase* load( @@ -280,9 +284,10 @@ namespace opr { const OperatorNodeConfig &config) { mgb_assert(inputs.empty()); auto &&opr = opr_.cast_final_safe(); - return Opr::make( - *ctx.owner_graph(opr, inputs), opr.dev_data(), config). 
- node()->owner_opr(); + return Opr::make(*ctx.owner_graph(opr, inputs), opr.dev_data(), + opr.const_value(), config) + .node() + ->owner_opr(); } cg::OperatorNodeBase* opr_shallow_copy_immutable_tensor( diff --git a/src/opr/include/megbrain/opr/io.h b/src/opr/include/megbrain/opr/io.h index 54a68e4d4..3be2c2515 100644 --- a/src/opr/include/megbrain/opr/io.h +++ b/src/opr/include/megbrain/opr/io.h @@ -75,19 +75,22 @@ class DeviceTensorHolder: public HostIONodeBase { */ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { std::shared_ptr m_dev_data; + DeviceTensorND m_static_infer; + bool m_const_value; const TensorShape& get_output_shape() override; - bool fill_in_static_infer(DeviceTensorND* dest) override { - MGB_MARK_USED_VAR(dest); - return false; - } + bool fill_in_static_infer(DeviceTensorND* dest) override; void init_output_comp_node() override; public: + //! const_value marks whether the device value of this operator should + //! be treated as constant during graph execution. Should be false in + //! most cases. SharedDeviceTensorBase(ComputingGraph &graph, const std::shared_ptr &dev_data, + bool const_value, const OperatorNodeConfig &config); const DeviceTensorND& get_dev_tensor() const override { @@ -97,6 +100,8 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { const std::shared_ptr& dev_data() const { return m_dev_data; } + + bool const_value() const { return m_const_value; } }; /*! @@ -104,6 +109,7 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // { * device tensors * * This opr is used to speed up inference by packing params together. + * This operator assumes the device tensors are constant. 
*/ MGB_DEFINE_CLS_WITH_SUPER(MultipleDeviceTensorHolderBase, cg::OperatorNodeBase) // { @@ -125,6 +131,8 @@ private: void init_output_comp_node() override; void init_output_static_infer_desc() override; NodeProp* do_make_node_prop() const override; + + SmallVector m_host_values; }; } // namespace intl @@ -249,16 +257,43 @@ MGB_DEFINE_OPR_CLASS(SharedDeviceTensor, intl::SharedDeviceTensorBase) // { public: using Super::Super; - static SymbolVar make(ComputingGraph &graph, - const std::shared_ptr &dev_data, - const OperatorNodeConfig &config = {}); + static SymbolVar make(ComputingGraph& graph, + const std::shared_ptr& dev_data, + bool const_value, + const OperatorNodeConfig& config); + + static SymbolVar make(ComputingGraph& graph, + const std::shared_ptr& dev_data, + const OperatorNodeConfig& config = {}) { + return make(graph, dev_data, false, config); + } + + static SymbolVar make_const( + ComputingGraph& graph, + const std::shared_ptr& dev_data, + const OperatorNodeConfig& config = {}) { + return make(graph, dev_data, true, config); + } /*! * \brief make a SharedDeviceTensor by first coping from host to device + * + * See SharedDeviceTensorBase::SharedDeviceTensorBase for const_value. */ - static SymbolVar make(ComputingGraph &graph, - const HostTensorND &value, - const OperatorNodeConfig &config = {}); + static SymbolVar make(ComputingGraph& graph, const HostTensorND& value, + bool const_value, + const OperatorNodeConfig& config); + + static SymbolVar make(ComputingGraph& graph, const HostTensorND& value, + const OperatorNodeConfig& config = {}) { + return make(graph, value, false, config); + } + + static SymbolVar make_const(ComputingGraph& graph, + const HostTensorND& value, + const OperatorNodeConfig& config = {}) { + return make(graph, value, true, config); + } }; /*!
@@ -276,7 +311,19 @@ public: static SymbolVar make(ComputingGraph& graph, const std::shared_ptr& dev_data, - const OperatorNodeConfig& config = {}); + bool const_value, const OperatorNodeConfig& config); + + static SymbolVar make(ComputingGraph& graph, + const std::shared_ptr& dev_data, + const OperatorNodeConfig& config = {}) { + return make(graph, dev_data, false, config); + } + + static SymbolVar make_const(ComputingGraph& graph, + const std::shared_ptr& dev_data, + const OperatorNodeConfig& config = {}) { + return make(graph, dev_data, true, config); + } }; /*! @@ -297,6 +344,15 @@ MGB_DEFINE_OPR_CLASS( static SymbolVar make(ComputingGraph &graph, const std::shared_ptr &dev_data, const OperatorNodeConfig &config = {}); + + //! adapter for io.sereg.h: opr_shallow_copy_shared_device_tensor + static SymbolVar make(ComputingGraph& graph, + const std::shared_ptr& dev_data, + bool const_value, + const OperatorNodeConfig& config) { + mgb_assert(!const_value); + return make(graph, dev_data, config); + } }; /*! -- GitLab