diff --git a/dnn/include/megdnn/oprs/general.h b/dnn/include/megdnn/oprs/general.h index 559bcc3fcbd1c37faa5d18d8496310e60a13e7e5..3fddc64d8d6fbe66f247fe73f6b24d60a123208b 100644 --- a/dnn/include/megdnn/oprs/general.h +++ b/dnn/include/megdnn/oprs/general.h @@ -469,22 +469,23 @@ using Split = SplitForward; * large number of inputs and can handle alignment requirements. Axis is also * not supported. * - * The table can be generated by gen_table(). The \p srcs in ParamPackSplit and + * The offsets can be generated by gen_offsets(). The \p srcs in ParamPackSplit and * \p dsts in ParamPackConcat must be on CPU, and must remain valid until the * execution stream is synchronized. */ class ParamPackConcatSplitBase : public OperatorBase { protected: - void check_exec(const TensorLayout& concated, const TensorLayout& table, + void check_exec(const TensorLayout& concated, const TensorLayout& offsets, const TensorLayout& parts); public: using Param = megdnn::param::Empty; ParamPackConcatSplitBase(Handle* handle) : OperatorBase(handle) {} - //! generate table to be used with ParamPackConcat and ParamPackSplit - static std::vector gen_table(const TensorShapeArray& shapes, - size_t alignment, size_t dtype_size); + //! generate offsets to be used with ParamPackConcat and ParamPackSplit + static std::vector gen_offsets(const TensorShapeArray& shapes, + size_t alignment, + size_t dtype_size); }; /** diff --git a/dnn/src/common/param_pack.cpp b/dnn/src/common/param_pack.cpp index e54093b75606e88bf1a46583be2bec295b1e5027..bd5e5f77bb7ef3d5369be5f17d94301bd981e3a0 100644 --- a/dnn/src/common/param_pack.cpp +++ b/dnn/src/common/param_pack.cpp @@ -29,7 +29,7 @@ void ParamPackConcatSplitBase::check_exec(const TensorLayout& concated, "concated=%zu table=%zu", concated.shape[0], table.shape[0]); } -std::vector ParamPackConcatSplitBase::gen_table( +std::vector ParamPackConcatSplitBase::gen_offsets( const TensorShapeArray& shapes, size_t alignment, size_t dtype_size) { megdnn_assert(alignment && (alignment & (alignment - 1)) == 0, "alignment must be power of 2: %zu", alignment); @@ -46,30 +46,13 @@ std::vector ParamPackConcatSplitBase::gen_table( return v + ((alignment - mod) & (alignment - 1)); }; + std::vector offsets(shapes.size()); size_t offset = 0; - for (auto&& i : shapes) { - offset = get_aligned(offset) + i.total_nr_elems(); + for (size_t i = 0; i < shapes.size(); i++) { + offsets[i] = offset; + offset = get_aligned(offset) + shapes[i].total_nr_elems(); } - - std::vector table(offset * 2); - auto outer_table = table.data(), inner_table = outer_table + offset; - - offset = 0; - for (size_t i = 0; i < shapes.size(); ++i) { - auto aligned = get_aligned(offset); - for (size_t j = offset; j < aligned; ++j) { - inner_table[j] = outer_table[j] = -1; - } - offset = aligned; - auto cur_size = shapes[i].total_nr_elems(); - for (size_t j = 0; j < cur_size; ++j) { - outer_table[offset + j] = i; - inner_table[offset + j] = j; - } - offset += cur_size; - } - megdnn_assert(offset * 2 == table.size()); - return table; + return offsets; } // vim: syntax=cpp.doxygen diff --git a/dnn/test/cuda/param_pack.cpp b/dnn/test/cuda/param_pack.cpp index 85f423613693458ae6c1c9bb84e4979990421946..8406e0db5496626294edeaa183b5c3982eb0ab3a 100644 --- a/dnn/test/cuda/param_pack.cpp +++ b/dnn/test/cuda/param_pack.cpp @@ -112,8 +112,8 @@ void test_param_pack_split(Handle* handle, const TensorShapeArray& shapes, std::vector table = create_table(shapes, handle->alignment_requirement()); ASSERT_EQ(table, - ParamPackSplit::gen_table(shapes, handle->alignment_requirement(), - sizeof(T))); + ParamPackSplit::gen_offsets( + shapes, handle->alignment_requirement(), sizeof(T))); size_t pack_size = table.size() / 2; int32_t* table_gpu = create_device_data(handle, table.data(), table.size()); diff --git a/python_module/src/cpp/opr_defs.cpp b/python_module/src/cpp/opr_defs.cpp index 6dc69be54e7414ac64cf36d6f5a2f12ba35d3258..021e513ede14b03ffb5dc3050ca5efa5b5620d49 100644 --- a/python_module/src/cpp/opr_defs.cpp +++ b/python_module/src/cpp/opr_defs.cpp @@ -47,19 +47,19 @@ SymbolVarArray _Opr::param_pack_split( shapearr[i] = npy::vec2shape(shapes[i]); } + auto cn = src.node()->comp_node(); + auto table_val = megdnn::ParamPackSplit::gen_offsets( + shapearr, cn.get_mem_addr_alignment(), src.dtype().size()); if (!table.node()) { - auto cn = src.node()->comp_node(); if (config.has_comp_node_set()) { cn = config.get_single_comp_node(); } - auto table_val = megdnn::ParamPackSplit::gen_table( - shapearr, cn.get_mem_addr_alignment(), src.dtype().size()); - HostTensorND hv{cn, TensorShape{table_val.size()}, dtype::Int32{}}; + HostTensorND hv{cn, TensorShape{{table_val.size()}}, dtype::Int32{}}; memcpy(hv.raw_ptr(), table_val.data(), table_val.size() * sizeof(int)); table = opr::ImmutableTensor::make(*src.node()->owner_graph(), hv); } - return mgb::opr::ParamPackSplit::make(src, table, shapearr, config); + return mgb::opr::ParamPackSplit::make(src, table, table_val, shapearr, config); } #if MGB_ENABLE_OPR_MM diff --git a/src/opr/impl/tensor_manip.cpp b/src/opr/impl/tensor_manip.cpp index a3ba159878da82ab8cc56b122be7a7c9954e244b..76f2ba33cd6be7bb3773dd11df4c9d94293586e6 100644 --- a/src/opr/impl/tensor_manip.cpp +++ b/src/opr/impl/tensor_manip.cpp @@ -1430,20 +1430,22 @@ void ParamPackConcat::on_output_comp_node_stream_changed(){ /* f{{{ ======================= ParamPackSplit ======================= */ MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit); -ParamPackSplit::ParamPackSplit(VarNode* src, VarNode* table, - TensorShapeArray& shapes, const OperatorNodeConfig& config) - : Super{src->owner_graph(), config, "ParamPackSplit", {src, table}}, - m_shapes(shapes){ - mgb_assert(src->comp_node() == table->comp_node()); +ParamPackSplit::ParamPackSplit(VarNode* src, VarNode* offsets, + const std::vector offsets_val, + TensorShapeArray& shapes, + const OperatorNodeConfig& config) + : Super{src->owner_graph(), config, "ParamPackSplit", {src, offsets}}, + m_shapes(shapes), m_offsets(offsets_val) { + mgb_assert(src->comp_node() == offsets->comp_node()); add_input({src}); - add_input({table}); + add_input({offsets}); + m_mem_fwd_success.resize(m_shapes.size()); for (size_t i = 0; i < shapes.size(); i++) { mgb_assert(shapes[i].total_nr_elems(), "empty param is not allowed!"); - add_output(ssprintf("param_pack_o%zu", i))->dtype(src->dtype()); + add_output(ssprintf("param_pack_o%zu", i)) + ->dtype(src->dtype()).shape(shapes[i]); } - - cg::add_workspace_output(this); } void ParamPackSplit::add_input_layout_constraint(){ @@ -1451,17 +1453,19 @@ void ParamPackSplit::add_input_layout_constraint(){ } SymbolVarArray ParamPackSplit::make(const SymbolVar& src, - const SymbolVar& table, + const SymbolVar& offsets, + const std::vector offsets_val, TensorShapeArray shapes, const OperatorNodeConfig& config) { auto&& out = src.node() ->owner_graph() ->insert_opr(std::make_unique( - src.node(), table.node(), shapes, config)) + src.node(), offsets.node(), offsets_val, + shapes, config)) ->output(); SymbolVarArray ret; - ret.resize(out.size() - 1); // do not return workspace + ret.resize(out.size()); for (size_t i = 0; i < ret.size(); ++i) { ret[i] = out[i]; } @@ -1469,41 +1473,25 @@ SymbolVarArray ParamPackSplit::make(const SymbolVar& src, } void ParamPackSplit::scn_do_execute() { - mgb_assert(m_opr.comp_node() == comp_node()); - megdnn::TensorND src = input(0)->dev_tensor().as_megdnn(), - table = input(1)->dev_tensor().as_megdnn(); - auto outputs = output(); - m_inp_ptr.resize(outputs.size() - 1); - auto ptr = m_inp_ptr.data(); - - for (size_t i = 0; i < outputs.size() - 1; i++) { - ptr[i] = outputs[i]->dev_tensor().as_megdnn().raw_ptr; - } - megdnn::TensorND dsts( - ptr, megdnn::TensorLayout({outputs.size() - 1}, dtype::Int32())); - - m_opr->exec(src, table, dsts, - get_megdnn_workspace_from_var(outputs.back())); -} - -void ParamPackSplit::on_output_comp_node_stream_changed() { - Super::on_output_comp_node_stream_changed(); - init_megdnn_opr(); -} - -void ParamPackSplit::init_megdnn_opr(){ - m_opr = intl::create_megdnn_opr(comp_node()); } void ParamPackSplit::init_output_dtype() { // already initialized in constructor } +void ParamPackSplit::mem_plan_fwd_in2out_readonly() { + mgb_assert(m_offsets.size() == output().size()); + for (size_t i = 0; i < output().size(); i++) { + auto layout = output(i)->layout(); + auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i]); + m_mem_fwd_success[i] = output(i)->set_fwd_in2out_readonly( + input(0), spec); + mgb_assert(m_mem_fwd_success[i]); + } +} + bool ParamPackSplit::infer_shape(size_t index, TensorShape& dest, const cg::static_infer::InpVal& inp) { - if (!m_opr.get()){ - init_megdnn_opr(); - } dest = m_shapes[index]; return true; } @@ -1515,33 +1503,19 @@ void ParamPackSplit::init_output_static_infer_desc() { DepVal shp_deps{{input(0), DepType::SHAPE}, {input(1), DepType::SHAPE}}; - auto infer_wk = [this](TensorShape &dst, const InpVal &inp){ - dst.ndim = 1; - - if(!m_opr.get()){ - init_megdnn_opr(); - } - - dst.shape[0] = m_opr->get_workspace_in_bytes( - inp.val.at(0).shape(), inp.val.at(1).shape(), m_shapes); - return true; - }; - - for (size_t i = 0; i < output().size() - 1; i++) { + for (size_t i = 0; i < output().size(); i++) { auto ov = output(i); mgr.register_shape_infer( ov, {SourceType::DEP, shp_deps, std::bind(&ParamPackSplit::infer_shape, this, i, _1, _2)}); } - mgr.register_shape_infer( - output().back(), {SourceType::DEP, shp_deps, infer_wk}); } MGB_IMPL_OPR_GRAD(ParamPackSplit) { mgb_assert(out_grad.size() == opr.output().size()); SmallVector grad; // last var is workspace, ignore it - for (size_t i = 0; i < out_grad.size() - 1; ++i) { + for (size_t i = 0; i < out_grad.size(); ++i) { auto gval = out_grad[i]; if (!gval) { gval = SymbolVar{opr.output(i)}.fill_retain_dtype(0).node(); diff --git a/src/opr/impl/tensor_manip.sereg.h b/src/opr/impl/tensor_manip.sereg.h index 4d8b5982e63e3e3638dc657b28ed089e3a3054ca..4e09bcbfb5c9c82a285a4045cc51cf5e58534d2e 100644 --- a/src/opr/impl/tensor_manip.sereg.h +++ b/src/opr/impl/tensor_manip.sereg.h @@ -185,9 +185,10 @@ namespace opr { const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, const OperatorNodeConfig &config){ auto &&opr = opr_.cast_final_safe(); + auto &&offsets = opr.get_offsets(); auto &&shape = opr.get_output_shapes(); - return ParamPackSplit::make(inputs[0], inputs[1], shape, config).at(0). + return ParamPackSplit::make(inputs[0], inputs[1], offsets, shape, config).at(0). node()->owner_opr(); } diff --git a/src/opr/include/megbrain/opr/tensor_manip.h b/src/opr/include/megbrain/opr/tensor_manip.h index 72b96d5e88cddd63035f6c184646e172484abbc5..b267fb9cef5f1584f38bdb3c51bde482d9adae66 100644 --- a/src/opr/include/megbrain/opr/tensor_manip.h +++ b/src/opr/include/megbrain/opr/tensor_manip.h @@ -570,31 +570,31 @@ public: * \brief Opr used to split parameter */ MGB_DEFINE_OPR_CLASS(ParamPackSplit, cg::SingleCNOperatorNodeBase) // { - //! input pointer buffer - SmallVector m_inp_ptr; - - intl::UniqPtrWithCN m_opr; TensorShapeArray m_shapes; + std::vector m_offsets; + std::vector m_mem_fwd_success; void scn_do_execute() override; void init_output_static_infer_desc() override; - void on_output_comp_node_stream_changed() override; - bool infer_shape(size_t index, TensorShape &dest, const cg::static_infer::InpVal &inp); - void init_output_dtype() override; - + void mem_plan_fwd_in2out_readonly() override; void add_input_layout_constraint() override; - void init_megdnn_opr(); - public: - ParamPackSplit(VarNode* src, VarNode* table, TensorShapeArray& shapes, - const OperatorNodeConfig &config); + ParamPackSplit(VarNode* src, VarNode* offsets, + const std::vector offsets_val, + TensorShapeArray& shapes, const OperatorNodeConfig& config); + + static SymbolVarArray make(const SymbolVar& src, const SymbolVar& offsets, + const std::vector offsets_val, + TensorShapeArray shapes, + const OperatorNodeConfig& config = {}); - static SymbolVarArray make(const SymbolVar &src, const SymbolVar &table, - TensorShapeArray shapes, const OperatorNodeConfig &config = {}); + const std::vector& get_offsets() const { + return m_offsets; + } const TensorShapeArray& get_output_shapes() const { return m_shapes; diff --git a/src/opr/test/tensor_manip.cpp b/src/opr/test/tensor_manip.cpp index 564adb982222bbc36d4752e7bd2d07a45ec0c427..45635aca943864c9263c5f36019c0043326f23fe 100644 --- a/src/opr/test/tensor_manip.cpp +++ b/src/opr/test/tensor_manip.cpp @@ -1898,7 +1898,7 @@ void test_param_pack_concat(const TensorShapeArray &shapes, DType type){ srcs.push_back(nd); } - auto host_table_gen = megdnn::ParamPackSplit::gen_table(shapes, + auto host_table_gen = megdnn::ParamPackSplit::gen_offsets(shapes, cn.get_mem_addr_alignment(), 4); ASSERT_EQ(host_table_gen.size(), size * 2); auto host_table = std::make_shared(); @@ -1944,7 +1944,7 @@ void test_param_pack_split(const TensorShapeArray& shapes) { auto make_graph = [&](const typename Checker::SymInpArray& inputs) -> typename Checker::SymOutArray { - auto table_val = megdnn::ParamPackSplit::gen_table( + auto table_val = megdnn::ParamPackSplit::gen_offsets( shapes, cn.get_mem_addr_alignment(), 4); HostTensorND table; std::copy_n(table_val.data(), table_val.size(), @@ -1954,7 +1954,8 @@ void test_param_pack_split(const TensorShapeArray& shapes) { .ptr()); auto sym_table = opr::SharedDeviceTensor::make( *inputs[0].node()->owner_graph(), table); - auto out = opr::ParamPackSplit::make(inputs[0], sym_table, shapes); + auto out = opr::ParamPackSplit::make(inputs[0], sym_table, table_val, + shapes); mgb_assert(out.size() == nr_out); typename Checker::SymOutArray ret; for (size_t i = 0; i < nr_out; ++i) {