提交 b708f15d 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

refactor(mgb/param_pack): use shared mem for param pack

GitOrigin-RevId: bc56f09037e9f7d5118df725a06d94f2c0727242
上级 f18259d7
...@@ -469,22 +469,23 @@ using Split = SplitForward; ...@@ -469,22 +469,23 @@ using Split = SplitForward;
* large number of inputs and can handle alignment requirements. Axis is also * large number of inputs and can handle alignment requirements. Axis is also
* not supported. * not supported.
* *
* The table can be generated by gen_table(). The \p srcs in ParamPackSplit and * The offsets can be generated by gen_offsets(). The \p srcs in ParamPackSplit and
* \p dsts in ParamPackConcat must be on CPU, and must remain valid until the * \p dsts in ParamPackConcat must be on CPU, and must remain valid until the
* execution stream is synchronized. * execution stream is synchronized.
*/ */
class ParamPackConcatSplitBase : public OperatorBase { class ParamPackConcatSplitBase : public OperatorBase {
protected: protected:
void check_exec(const TensorLayout& concated, const TensorLayout& table, void check_exec(const TensorLayout& concated, const TensorLayout& offsets,
const TensorLayout& parts); const TensorLayout& parts);
public: public:
using Param = megdnn::param::Empty; using Param = megdnn::param::Empty;
ParamPackConcatSplitBase(Handle* handle) : OperatorBase(handle) {} ParamPackConcatSplitBase(Handle* handle) : OperatorBase(handle) {}
//! generate table to be used with ParamPackConcat and ParamPackSplit //! generate offsets to be used with ParamPackConcat and ParamPackSplit
static std::vector<dt_int32> gen_table(const TensorShapeArray& shapes, static std::vector<dt_int32> gen_offsets(const TensorShapeArray& shapes,
size_t alignment, size_t dtype_size); size_t alignment,
size_t dtype_size);
}; };
/** /**
......
...@@ -29,7 +29,7 @@ void ParamPackConcatSplitBase::check_exec(const TensorLayout& concated, ...@@ -29,7 +29,7 @@ void ParamPackConcatSplitBase::check_exec(const TensorLayout& concated,
"concated=%zu table=%zu", concated.shape[0], table.shape[0]); "concated=%zu table=%zu", concated.shape[0], table.shape[0]);
} }
std::vector<dt_int32> ParamPackConcatSplitBase::gen_table( std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets(
const TensorShapeArray& shapes, size_t alignment, size_t dtype_size) { const TensorShapeArray& shapes, size_t alignment, size_t dtype_size) {
megdnn_assert(alignment && (alignment & (alignment - 1)) == 0, megdnn_assert(alignment && (alignment & (alignment - 1)) == 0,
"alignment must be power of 2: %zu", alignment); "alignment must be power of 2: %zu", alignment);
...@@ -46,30 +46,13 @@ std::vector<dt_int32> ParamPackConcatSplitBase::gen_table( ...@@ -46,30 +46,13 @@ std::vector<dt_int32> ParamPackConcatSplitBase::gen_table(
return v + ((alignment - mod) & (alignment - 1)); return v + ((alignment - mod) & (alignment - 1));
}; };
std::vector<dt_int32> offsets(shapes.size());
size_t offset = 0; size_t offset = 0;
for (auto&& i : shapes) { for (size_t i = 0; i < shapes.size(); i++) {
offset = get_aligned(offset) + i.total_nr_elems(); offsets[i] = offset;
offset = get_aligned(offset) + shapes[i].total_nr_elems();
} }
return offsets;
std::vector<dt_int32> table(offset * 2);
auto outer_table = table.data(), inner_table = outer_table + offset;
offset = 0;
for (size_t i = 0; i < shapes.size(); ++i) {
auto aligned = get_aligned(offset);
for (size_t j = offset; j < aligned; ++j) {
inner_table[j] = outer_table[j] = -1;
}
offset = aligned;
auto cur_size = shapes[i].total_nr_elems();
for (size_t j = 0; j < cur_size; ++j) {
outer_table[offset + j] = i;
inner_table[offset + j] = j;
}
offset += cur_size;
}
megdnn_assert(offset * 2 == table.size());
return table;
} }
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -112,8 +112,8 @@ void test_param_pack_split(Handle* handle, const TensorShapeArray& shapes, ...@@ -112,8 +112,8 @@ void test_param_pack_split(Handle* handle, const TensorShapeArray& shapes,
std::vector<int32_t> table = std::vector<int32_t> table =
create_table<T>(shapes, handle->alignment_requirement()); create_table<T>(shapes, handle->alignment_requirement());
ASSERT_EQ(table, ASSERT_EQ(table,
ParamPackSplit::gen_table(shapes, handle->alignment_requirement(), ParamPackSplit::gen_offsets(
sizeof(T))); shapes, handle->alignment_requirement(), sizeof(T)));
size_t pack_size = table.size() / 2; size_t pack_size = table.size() / 2;
int32_t* table_gpu = create_device_data<int32_t>(handle, table.data(), int32_t* table_gpu = create_device_data<int32_t>(handle, table.data(),
table.size()); table.size());
......
...@@ -47,19 +47,19 @@ SymbolVarArray _Opr::param_pack_split( ...@@ -47,19 +47,19 @@ SymbolVarArray _Opr::param_pack_split(
shapearr[i] = npy::vec2shape(shapes[i]); shapearr[i] = npy::vec2shape(shapes[i]);
} }
if (!table.node()) {
auto cn = src.node()->comp_node(); auto cn = src.node()->comp_node();
auto table_val = megdnn::ParamPackSplit::gen_offsets(
shapearr, cn.get_mem_addr_alignment(), src.dtype().size());
if (!table.node()) {
if (config.has_comp_node_set()) { if (config.has_comp_node_set()) {
cn = config.get_single_comp_node(); cn = config.get_single_comp_node();
} }
auto table_val = megdnn::ParamPackSplit::gen_table( HostTensorND hv{cn, TensorShape{{table_val.size()}}, dtype::Int32{}};
shapearr, cn.get_mem_addr_alignment(), src.dtype().size());
HostTensorND hv{cn, TensorShape{table_val.size()}, dtype::Int32{}};
memcpy(hv.raw_ptr(), table_val.data(), table_val.size() * sizeof(int)); memcpy(hv.raw_ptr(), table_val.data(), table_val.size() * sizeof(int));
table = opr::ImmutableTensor::make(*src.node()->owner_graph(), hv); table = opr::ImmutableTensor::make(*src.node()->owner_graph(), hv);
} }
return mgb::opr::ParamPackSplit::make(src, table, shapearr, config); return mgb::opr::ParamPackSplit::make(src, table, table_val, shapearr, config);
} }
#if MGB_ENABLE_OPR_MM #if MGB_ENABLE_OPR_MM
......
...@@ -1430,20 +1430,22 @@ void ParamPackConcat::on_output_comp_node_stream_changed(){ ...@@ -1430,20 +1430,22 @@ void ParamPackConcat::on_output_comp_node_stream_changed(){
/* f{{{ ======================= ParamPackSplit ======================= */ /* f{{{ ======================= ParamPackSplit ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit); MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackSplit);
ParamPackSplit::ParamPackSplit(VarNode* src, VarNode* table, ParamPackSplit::ParamPackSplit(VarNode* src, VarNode* offsets,
TensorShapeArray& shapes, const OperatorNodeConfig& config) const std::vector<dt_int32> offsets_val,
: Super{src->owner_graph(), config, "ParamPackSplit", {src, table}}, TensorShapeArray& shapes,
m_shapes(shapes){ const OperatorNodeConfig& config)
mgb_assert(src->comp_node() == table->comp_node()); : Super{src->owner_graph(), config, "ParamPackSplit", {src, offsets}},
m_shapes(shapes), m_offsets(offsets_val) {
mgb_assert(src->comp_node() == offsets->comp_node());
add_input({src}); add_input({src});
add_input({table}); add_input({offsets});
m_mem_fwd_success.resize(m_shapes.size());
for (size_t i = 0; i < shapes.size(); i++) { for (size_t i = 0; i < shapes.size(); i++) {
mgb_assert(shapes[i].total_nr_elems(), "empty param is not allowed!"); mgb_assert(shapes[i].total_nr_elems(), "empty param is not allowed!");
add_output(ssprintf("param_pack_o%zu", i))->dtype(src->dtype()); add_output(ssprintf("param_pack_o%zu", i))
->dtype(src->dtype()).shape(shapes[i]);
} }
cg::add_workspace_output(this);
} }
void ParamPackSplit::add_input_layout_constraint(){ void ParamPackSplit::add_input_layout_constraint(){
...@@ -1451,17 +1453,19 @@ void ParamPackSplit::add_input_layout_constraint(){ ...@@ -1451,17 +1453,19 @@ void ParamPackSplit::add_input_layout_constraint(){
} }
SymbolVarArray ParamPackSplit::make(const SymbolVar& src, SymbolVarArray ParamPackSplit::make(const SymbolVar& src,
const SymbolVar& table, const SymbolVar& offsets,
const std::vector<dt_int32> offsets_val,
TensorShapeArray shapes, TensorShapeArray shapes,
const OperatorNodeConfig& config) { const OperatorNodeConfig& config) {
auto&& out = src.node() auto&& out = src.node()
->owner_graph() ->owner_graph()
->insert_opr(std::make_unique<ParamPackSplit>( ->insert_opr(std::make_unique<ParamPackSplit>(
src.node(), table.node(), shapes, config)) src.node(), offsets.node(), offsets_val,
shapes, config))
->output(); ->output();
SymbolVarArray ret; SymbolVarArray ret;
ret.resize(out.size() - 1); // do not return workspace ret.resize(out.size());
for (size_t i = 0; i < ret.size(); ++i) { for (size_t i = 0; i < ret.size(); ++i) {
ret[i] = out[i]; ret[i] = out[i];
} }
...@@ -1469,41 +1473,25 @@ SymbolVarArray ParamPackSplit::make(const SymbolVar& src, ...@@ -1469,41 +1473,25 @@ SymbolVarArray ParamPackSplit::make(const SymbolVar& src,
} }
void ParamPackSplit::scn_do_execute() { void ParamPackSplit::scn_do_execute() {
mgb_assert(m_opr.comp_node() == comp_node());
megdnn::TensorND src = input(0)->dev_tensor().as_megdnn(),
table = input(1)->dev_tensor().as_megdnn();
auto outputs = output();
m_inp_ptr.resize(outputs.size() - 1);
auto ptr = m_inp_ptr.data();
for (size_t i = 0; i < outputs.size() - 1; i++) {
ptr[i] = outputs[i]->dev_tensor().as_megdnn().raw_ptr;
}
megdnn::TensorND dsts(
ptr, megdnn::TensorLayout({outputs.size() - 1}, dtype::Int32()));
m_opr->exec(src, table, dsts,
get_megdnn_workspace_from_var(outputs.back()));
}
void ParamPackSplit::on_output_comp_node_stream_changed() {
Super::on_output_comp_node_stream_changed();
init_megdnn_opr();
}
void ParamPackSplit::init_megdnn_opr(){
m_opr = intl::create_megdnn_opr<megdnn::ParamPackSplit>(comp_node());
} }
void ParamPackSplit::init_output_dtype() { void ParamPackSplit::init_output_dtype() {
// already initialized in constructor // already initialized in constructor
} }
void ParamPackSplit::mem_plan_fwd_in2out_readonly() {
mgb_assert(m_offsets.size() == output().size());
for (size_t i = 0; i < output().size(); i++) {
auto layout = output(i)->layout();
auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i]);
m_mem_fwd_success[i] = output(i)->set_fwd_in2out_readonly(
input(0), spec);
mgb_assert(m_mem_fwd_success[i]);
}
}
bool ParamPackSplit::infer_shape(size_t index, TensorShape& dest, bool ParamPackSplit::infer_shape(size_t index, TensorShape& dest,
const cg::static_infer::InpVal& inp) { const cg::static_infer::InpVal& inp) {
if (!m_opr.get()){
init_megdnn_opr();
}
dest = m_shapes[index]; dest = m_shapes[index];
return true; return true;
} }
...@@ -1515,33 +1503,19 @@ void ParamPackSplit::init_output_static_infer_desc() { ...@@ -1515,33 +1503,19 @@ void ParamPackSplit::init_output_static_infer_desc() {
DepVal shp_deps{{input(0), DepType::SHAPE}, {input(1), DepType::SHAPE}}; DepVal shp_deps{{input(0), DepType::SHAPE}, {input(1), DepType::SHAPE}};
auto infer_wk = [this](TensorShape &dst, const InpVal &inp){ for (size_t i = 0; i < output().size(); i++) {
dst.ndim = 1;
if(!m_opr.get()){
init_megdnn_opr();
}
dst.shape[0] = m_opr->get_workspace_in_bytes(
inp.val.at(0).shape(), inp.val.at(1).shape(), m_shapes);
return true;
};
for (size_t i = 0; i < output().size() - 1; i++) {
auto ov = output(i); auto ov = output(i);
mgr.register_shape_infer( mgr.register_shape_infer(
ov, {SourceType::DEP, shp_deps, ov, {SourceType::DEP, shp_deps,
std::bind(&ParamPackSplit::infer_shape, this, i, _1, _2)}); std::bind(&ParamPackSplit::infer_shape, this, i, _1, _2)});
} }
mgr.register_shape_infer(
output().back(), {SourceType::DEP, shp_deps, infer_wk});
} }
MGB_IMPL_OPR_GRAD(ParamPackSplit) { MGB_IMPL_OPR_GRAD(ParamPackSplit) {
mgb_assert(out_grad.size() == opr.output().size()); mgb_assert(out_grad.size() == opr.output().size());
SmallVector<SymbolVar> grad; SmallVector<SymbolVar> grad;
// last var is workspace, ignore it // no workspace output anymore: every out_grad entry belongs to a split output
for (size_t i = 0; i < out_grad.size() - 1; ++i) { for (size_t i = 0; i < out_grad.size(); ++i) {
auto gval = out_grad[i]; auto gval = out_grad[i];
if (!gval) { if (!gval) {
gval = SymbolVar{opr.output(i)}.fill_retain_dtype(0).node(); gval = SymbolVar{opr.output(i)}.fill_retain_dtype(0).node();
......
...@@ -185,9 +185,10 @@ namespace opr { ...@@ -185,9 +185,10 @@ namespace opr {
const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
const OperatorNodeConfig &config){ const OperatorNodeConfig &config){
auto &&opr = opr_.cast_final_safe<ParamPackSplit>(); auto &&opr = opr_.cast_final_safe<ParamPackSplit>();
auto &&offsets = opr.get_offsets();
auto &&shape = opr.get_output_shapes(); auto &&shape = opr.get_output_shapes();
return ParamPackSplit::make(inputs[0], inputs[1], shape, config).at(0). return ParamPackSplit::make(inputs[0], inputs[1], offsets, shape, config).at(0).
node()->owner_opr(); node()->owner_opr();
} }
......
...@@ -570,31 +570,31 @@ public: ...@@ -570,31 +570,31 @@ public:
* \brief Opr used to split parameter * \brief Opr used to split parameter
*/ */
MGB_DEFINE_OPR_CLASS(ParamPackSplit, cg::SingleCNOperatorNodeBase) // { MGB_DEFINE_OPR_CLASS(ParamPackSplit, cg::SingleCNOperatorNodeBase) // {
//! input pointer buffer
SmallVector<void*> m_inp_ptr;
intl::UniqPtrWithCN<megdnn::ParamPackSplit> m_opr;
TensorShapeArray m_shapes; TensorShapeArray m_shapes;
std::vector<dt_int32> m_offsets;
std::vector<bool> m_mem_fwd_success;
void scn_do_execute() override; void scn_do_execute() override;
void init_output_static_infer_desc() override; void init_output_static_infer_desc() override;
void on_output_comp_node_stream_changed() override;
bool infer_shape(size_t index, TensorShape &dest, bool infer_shape(size_t index, TensorShape &dest,
const cg::static_infer::InpVal &inp); const cg::static_infer::InpVal &inp);
void init_output_dtype() override; void init_output_dtype() override;
void mem_plan_fwd_in2out_readonly() override;
void add_input_layout_constraint() override; void add_input_layout_constraint() override;
void init_megdnn_opr();
public: public:
ParamPackSplit(VarNode* src, VarNode* table, TensorShapeArray& shapes, ParamPackSplit(VarNode* src, VarNode* offsets,
const OperatorNodeConfig &config); const std::vector<dt_int32> offsets_val,
TensorShapeArray& shapes, const OperatorNodeConfig& config);
static SymbolVarArray make(const SymbolVar& src, const SymbolVar& offsets,
const std::vector<dt_int32> offsets_val,
TensorShapeArray shapes,
const OperatorNodeConfig& config = {});
static SymbolVarArray make(const SymbolVar &src, const SymbolVar &table, const std::vector<dt_int32>& get_offsets() const {
TensorShapeArray shapes, const OperatorNodeConfig &config = {}); return m_offsets;
}
const TensorShapeArray& get_output_shapes() const { const TensorShapeArray& get_output_shapes() const {
return m_shapes; return m_shapes;
......
...@@ -1898,7 +1898,7 @@ void test_param_pack_concat(const TensorShapeArray &shapes, DType type){ ...@@ -1898,7 +1898,7 @@ void test_param_pack_concat(const TensorShapeArray &shapes, DType type){
srcs.push_back(nd); srcs.push_back(nd);
} }
auto host_table_gen = megdnn::ParamPackSplit::gen_table(shapes, auto host_table_gen = megdnn::ParamPackSplit::gen_offsets(shapes,
cn.get_mem_addr_alignment(), 4); cn.get_mem_addr_alignment(), 4);
ASSERT_EQ(host_table_gen.size(), size * 2); ASSERT_EQ(host_table_gen.size(), size * 2);
auto host_table = std::make_shared<HostTensorND>(); auto host_table = std::make_shared<HostTensorND>();
...@@ -1944,7 +1944,7 @@ void test_param_pack_split(const TensorShapeArray& shapes) { ...@@ -1944,7 +1944,7 @@ void test_param_pack_split(const TensorShapeArray& shapes) {
auto make_graph = [&](const typename Checker::SymInpArray& inputs) -> auto make_graph = [&](const typename Checker::SymInpArray& inputs) ->
typename Checker::SymOutArray { typename Checker::SymOutArray {
auto table_val = megdnn::ParamPackSplit::gen_table( auto table_val = megdnn::ParamPackSplit::gen_offsets(
shapes, cn.get_mem_addr_alignment(), 4); shapes, cn.get_mem_addr_alignment(), 4);
HostTensorND table; HostTensorND table;
std::copy_n(table_val.data(), table_val.size(), std::copy_n(table_val.data(), table_val.size(),
...@@ -1954,7 +1954,8 @@ void test_param_pack_split(const TensorShapeArray& shapes) { ...@@ -1954,7 +1954,8 @@ void test_param_pack_split(const TensorShapeArray& shapes) {
.ptr<dt_int32>()); .ptr<dt_int32>());
auto sym_table = opr::SharedDeviceTensor::make( auto sym_table = opr::SharedDeviceTensor::make(
*inputs[0].node()->owner_graph(), table); *inputs[0].node()->owner_graph(), table);
auto out = opr::ParamPackSplit::make(inputs[0], sym_table, shapes); auto out = opr::ParamPackSplit::make(inputs[0], sym_table, table_val,
shapes);
mgb_assert(out.size() == nr_out); mgb_assert(out.size() == nr_out);
typename Checker::SymOutArray ret; typename Checker::SymOutArray ret;
for (size_t i = 0; i < nr_out; ++i) { for (size_t i = 0; i < nr_out; ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册