提交 37b67c9b 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

refactor(dnn/parampack): reduce param pack memory use

GitOrigin-RevId: a802a14e8dbb2b291f05862bd9f0a12622d57f0c
上级 b708f15d
......@@ -15,18 +15,16 @@
using namespace megdnn;
void ParamPackConcatSplitBase::check_exec(const TensorLayout& concated,
const TensorLayout& table,
const TensorLayout& offsets,
const TensorLayout& parts) {
megdnn_assert(table.dtype == dtype::Int32{}, "bad dtype: %s",
table.dtype.name());
megdnn_assert(concated.ndim == 1 && table.ndim == 1 && parts.ndim == 1 &&
concated.stride[0] == 1 && table.stride[0] == 1 &&
megdnn_assert(offsets.dtype == dtype::Int32{}, "bad dtype: %s",
offsets.dtype.name());
megdnn_assert(concated.ndim == 1 && offsets.ndim == 1 && parts.ndim == 1 &&
concated.stride[0] == 1 && offsets.stride[0] == 1 &&
parts.stride[0] == 1,
"bad layout: concated=%s table=%s parts=%s",
concated.to_string().c_str(), table.to_string().c_str(),
"bad layout: concated=%s offsets=%s parts=%s",
concated.to_string().c_str(), offsets.to_string().c_str(),
parts.to_string().c_str());
megdnn_assert(table.shape[0] == concated.shape[0] * 2,
"concated=%zu table=%zu", concated.shape[0], table.shape[0]);
}
std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets(
......@@ -46,11 +44,13 @@ std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets(
return v + ((alignment - mod) & (alignment - 1));
};
std::vector<dt_int32> offsets(shapes.size());
std::vector<dt_int32> offsets(shapes.size() << 1);
size_t offset = 0;
for (size_t i = 0; i < shapes.size(); i++) {
offsets[i] = offset;
offset = get_aligned(offset) + shapes[i].total_nr_elems();
offset = get_aligned(offset);
offsets[i * 2] = offset;
offset += shapes[i].total_nr_elems();
offsets[i * 2 + 1] = offset;
}
return offsets;
}
......
......@@ -24,7 +24,7 @@ size_t ParamPackConcatImpl::get_workspace_in_bytes(const TensorShapeArray& srcs,
template <typename T>
void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs,
_megdnn_tensor_in table,
_megdnn_tensor_in offsets,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) {
size_t inp_size = srcs.layout.shape[0],
......@@ -35,25 +35,25 @@ void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs,
megdnn_assert_internal(src_cpu);
auto src_gpu = reinterpret_cast<const T**>(workspace.raw_ptr);
auto table_outer_gpu = table.ptr<int32_t>(),
table_inner_gpu = table_outer_gpu + out_size;
auto offsets_gpu = offsets.ptr<int32_t>();
cuda_check(cudaMemcpyAsync(src_gpu, src_cpu, sizeof(const T*) * inp_size,
cudaMemcpyHostToDevice, stream));
param_pack::concat_proxy<T>(src_gpu, dst.ptr<T>(), out_size,
table_outer_gpu, table_inner_gpu, stream);
param_pack::concat_proxy<T>(src_gpu, dst.ptr<T>(), inp_size, out_size,
offsets_gpu, stream);
}
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table,
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs,
_megdnn_tensor_in offsets,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(dst.layout, table.layout, srcs.layout);
#define cb(DType) \
if (dst.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
exec_internal<ctype>(srcs, table, dst, workspace); \
return; \
check_exec(dst.layout, offsets.layout, srcs.layout);
#define cb(DType) \
if (dst.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
exec_internal<ctype>(srcs, offsets, dst, workspace); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
megdnn_throw("bad type");
......
......@@ -19,17 +19,24 @@ namespace param_pack {
template <typename T>
__global__ void concat_kernel(const T** srcs, T* dst,
const int32_t* table_outer,
const int32_t* table_inner,
const int32_t* offsets,
size_t srcs_size,
size_t total_size) {
size_t addr = threadIdx.x + blockIdx.x * blockDim.x;
if (addr < total_size) {
int32_t i = table_outer[addr];
int32_t idx = table_inner[addr];
if (idx != -1)
dst[addr] = srcs[i][idx];
else
size_t l = 0, r = srcs_size - 1, mid;
while (l < r) {
mid = (l + r) >> 1;
if (offsets[(mid << 1) + 1] > addr) {
r = mid;
} else {
l = mid + 1;
}
}
if (addr < offsets[l << 1])
dst[addr] = 0;
else
dst[addr] = srcs[l][addr - offsets[l << 1]];
}
}
......@@ -59,20 +66,20 @@ void split_proxy(const T* src, T** dsts, size_t total_size,
}
template <typename T>
void concat_proxy(const T** srcs, T* dst, size_t total_size,
const int32_t* table_outer,
const int32_t* table_inner, cudaStream_t stream) {
void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size,
const int32_t* offsets,
cudaStream_t stream) {
size_t NR_BLOCKS = DIVUP(total_size, NR_THREADS);
concat_kernel<<<NR_BLOCKS, NR_THREADS, 0, stream>>>(
srcs, dst, table_outer, table_inner, total_size);
srcs, dst, offsets, srcs_size, total_size);
after_kernel_launch();
}
#define INST(T) \
template void concat_proxy<T>(const T**, T*, size_t, \
const int32_t*, const int32_t*, \
template void concat_proxy<T>(const T**, T*, size_t, size_t, \
const int32_t*, \
cudaStream_t); \
template void split_proxy<T>(const T*, T**, size_t, \
template void split_proxy<T>(const T*, T**, size_t, \
const int32_t*, const int32_t*, \
cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
......
......@@ -25,9 +25,8 @@ void split_proxy(const T* src, T** dsts, size_t total_size,
cudaStream_t stream);
template <typename T>
void concat_proxy(const T** srcs, T* dst, size_t total_size,
const int32_t* table_outer,
const int32_t* table_inner, cudaStream_t stream);
void concat_proxy(const T** srcs, T* dst, size_t srcs_size, size_t total_size,
const int32_t* offsets, cudaStream_t stream);
} // namespace param_pack
} // namespace cuda
......
......@@ -54,38 +54,40 @@ void ParamPackSplitImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in table,
}
template <typename T>
void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs, int32_t* table,
void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs,
int32_t* offsets,
_megdnn_tensor_out dst,
_megdnn_workspace) {
size_t out_size = dst.layout.total_nr_elems();
auto srcs_ptr = static_cast<const T**>(srcs.raw_ptr);
auto dst_ptr = dst.ptr<T>();
auto table_outer = table, table_inner = table_outer + out_size;
for (size_t j = 0; j < out_size; j++) {
int32_t i = table_outer[j];
int32_t idx = table_inner[j];
if (idx != -1)
dst_ptr[j] = srcs_ptr[i][idx];
else
dst_ptr[j] = 0;
int32_t last_pos = 0;
for (size_t i = 0; i < srcs.layout[0]; i++) {
int32_t begin = offsets[i * 2], end = offsets[i * 2 + 1];
while (last_pos < begin) {
dst_ptr[last_pos] = 0;
last_pos++;
}
for (int32_t j = 0; j < end - begin; j++) {
dst_ptr[begin + j] = srcs_ptr[i][j];
}
last_pos = end;
}
}
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs, _megdnn_tensor_in table,
void ParamPackConcatImpl::exec(_megdnn_tensor_in srcs,
_megdnn_tensor_in offsets,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(dst.layout, table.layout, srcs.layout);
auto table_ptr = table.ptr<int32_t>();
check_exec(dst.layout, offsets.layout, srcs.layout);
auto offsets_ptr = offsets.ptr<int32_t>();
#define cb(DType) \
if (dst.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
exec_internal<ctype>(srcs, table_ptr, dst, workspace)); \
return; \
#define cb(DType) \
if (dst.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
exec_internal<ctype>(srcs, offsets_ptr, dst, workspace)); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
megdnn_throw("bad type");
......
......@@ -1339,8 +1339,10 @@ void Concat::init_output_comp_node() {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ParamPackConcat);
ParamPackConcat::ParamPackConcat(VarNodeArray& inp, VarNode* table,
const std::vector<dt_int32> offsets_val,
const OperatorNodeConfig& config)
: Super(inp[0]->owner_graph(), config, "ParamPackConcat", inp) {
: Super(inp[0]->owner_graph(), config, "ParamPackConcat", inp),
m_offsets(offsets_val) {
CompNode cn = inp[0]->comp_node();
add_input({inp[0]});
for (size_t i = 1; i < inp.size(); i++) {
......@@ -1361,14 +1363,16 @@ void ParamPackConcat::add_input_layout_constraint(){
}
}
SymbolVar ParamPackConcat::make(const SmallVector<SymbolVar> &inp,
const SymbolVar &table, const OperatorNodeConfig& config) {
SymbolVar ParamPackConcat::make(const SmallVector<SymbolVar>& inp,
const SymbolVar& offsets,
const std::vector<dt_int32> offsets_val,
const OperatorNodeConfig& config) {
VarNodeArray array(inp.size());
for (size_t i = 0; i < inp.size(); i++) {
array[i] = inp[i].node();
}
return inp.front().
insert_single_output_opr<ParamPackConcat>(array, table.node(), config);
return inp.front().insert_single_output_opr<ParamPackConcat>(
array, offsets.node(), offsets_val, config);
}
void ParamPackConcat::scn_do_execute() {
......@@ -1379,13 +1383,13 @@ void ParamPackConcat::scn_do_execute() {
for (size_t i = 0; i < inputs.size() - 1; i++) {
ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr;
}
auto table = inputs.back()->dev_tensor().as_megdnn();
auto offsets = inputs.back()->dev_tensor().as_megdnn();
megdnn::TensorND srcs(
ptr, megdnn::TensorLayout({inputs.size() - 1}, dtype::Int32()));
auto&& dst = output(0)->dev_tensor().as_megdnn();
m_opr->exec(srcs, table, dst, get_megdnn_workspace_from_var(output(1)));
m_opr->exec(srcs, offsets, dst, get_megdnn_workspace_from_var(output(1)));
}
void ParamPackConcat::init_output_dtype() {
......@@ -1396,8 +1400,8 @@ void ParamPackConcat::init_output_static_infer_desc(){
using namespace cg::static_infer;
auto &&mgr = owner_graph()->static_infer_manager();
auto infer_out = [](TensorShape &dest, const InpVal &inp) {
dest = {inp.val.back().shape().total_nr_elems()/2};
auto infer_out = [this](TensorShape &dest, const InpVal &inp) {
dest = {m_offsets.back()};
return true;
};
DepVal shp_deps;
......@@ -1480,10 +1484,10 @@ void ParamPackSplit::init_output_dtype() {
}
void ParamPackSplit::mem_plan_fwd_in2out_readonly() {
mgb_assert(m_offsets.size() == output().size());
mgb_assert(m_offsets.size() == output().size() * 2);
for (size_t i = 0; i < output().size(); i++) {
auto layout = output(i)->layout();
auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i]);
auto spec = SubTensorSpec::make_from_offset_elem(layout, m_offsets[i * 2]);
m_mem_fwd_success[i] = output(i)->set_fwd_in2out_readonly(
input(0), spec);
mgb_assert(m_mem_fwd_success[i]);
......@@ -1524,7 +1528,7 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) {
}
return ParamPackConcat::make(
grad, opr.input(1),
grad, opr.input(1), opr.get_offsets(),
OperatorNodeConfig{}.follow_comp_node(opr.input(0)))
.node();
}
......
......@@ -31,31 +31,6 @@ namespace serialization {
struct OprMaker<opr::GetVarShape, 0>:
public OprMakerVariadic<opr::GetVarShape>{};
template<>
struct OprLoadDumpImpl<opr::ParamPackConcat, 0>
{
using ParamPackConcat = opr::ParamPackConcat;
using Param = opr::ParamPackConcat::Param;
static void dump(OprDumpContext &ctx,
const cg::OperatorNodeBase &opr_) {
auto &&opr = opr_.cast_final_safe<ParamPackConcat>();
ctx.write_param<Param>(opr.param());
}
static cg::OperatorNodeBase* load(
OprLoadContext &ctx, const cg::VarNodeArray &inputs,
const OperatorNodeConfig &config) {
auto param = ctx.read_param<Param>();
mgb_assert(!inputs.empty());
SymbolVarArray ivar{inputs.size() - 1};
for (size_t i = 0; i < inputs.size() - 1; ++ i)
ivar[i] = inputs[i];
return ParamPackConcat::make(ivar, inputs.back(),
param, config).node()->owner_opr();
}
};
template<>
struct OprLoadDumpImpl<opr::Split, 0> {
using Split = opr::Split;
......@@ -151,7 +126,6 @@ namespace opr {
MGB_SEREG_OPR(Dimshuffle, 1);
MGB_SEREG_OPR(AxisAddRemove, 1);
MGB_SEREG_OPR(Concat, 0);
MGB_SEREG_OPR(ParamPackConcat, 0);
using GetVarShapeV1 = opr::GetVarShape;
MGB_SEREG_OPR(GetVarShapeV1, 0);
using ReshapeV1 = opr::Reshape;
......@@ -193,6 +167,22 @@ namespace opr {
}
MGB_REG_OPR_SHALLOW_COPY(ParamPackSplit, opr_shallow_copy_param_pack_split);
cg::OperatorNodeBase* opr_shallow_copy_param_pack_concat(
const serialization::OprShallowCopyContext &ctx,
const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
const OperatorNodeConfig &config){
auto &&opr = opr_.cast_final_safe<ParamPackConcat>();
auto &&offsets = opr.get_offsets();
SymbolVarArray ivar{inputs.size() - 1};
for (size_t i = 0; i < inputs.size() - 1; ++i)
ivar[i] = inputs[i];
return ParamPackConcat::make(ivar, inputs.back(), offsets, config).
node()->owner_opr();
}
MGB_REG_OPR_SHALLOW_COPY(ParamPackConcat, opr_shallow_copy_param_pack_concat);
MGB_SEREG_OPR(RelayoutFormat, 1);
MGB_SEREG_OPR(WinogradFilterPreprocess, 1);
} // namespace opr
......
......@@ -539,6 +539,7 @@ MGB_DEFINE_OPR_CLASS(Concat, cg::SingleCNOutshapePureByInshapeOprBase) // {
MGB_DEFINE_OPR_CLASS(ParamPackConcat, cg::SingleCNOperatorNodeBase) // {
//! input pointer buffer
SmallVector<void*> m_inp_ptr;
std::vector<dt_int32> m_offsets;
intl::UniqPtrWithCN<megdnn::ParamPackConcat> m_opr;
void add_input_layout_constraint() override;
......@@ -554,15 +555,23 @@ public:
return {};
}
ParamPackConcat(VarNodeArray &inp, VarNode *table,
const OperatorNodeConfig &config);
static SymbolVar make(const SmallVector<SymbolVar> &inp,
const SymbolVar &table, const OperatorNodeConfig &config = {});
ParamPackConcat(VarNodeArray& inp, VarNode* offsets,
const std::vector<dt_int32> offsets_val,
const OperatorNodeConfig& config);
static SymbolVar make(const SmallVector<SymbolVar>& inp,
const SymbolVar& offsets,
const std::vector<dt_int32> offsets_val,
const OperatorNodeConfig& config = {});
static SymbolVar make(const SmallVector<SymbolVar>& inp,
const SymbolVar& offsets,
const std::vector<dt_int32> offsets_val, const Param&,
const OperatorNodeConfig& config) {
return make(inp, offsets, offsets_val, config);
}
static SymbolVar make(const SmallVector<SymbolVar> &inp,
const SymbolVar &table, const Param &,
const OperatorNodeConfig &config) {
return make(inp, table, config);
const std::vector<dt_int32>& get_offsets() const {
return m_offsets;
}
};
......
......@@ -1906,7 +1906,7 @@ void test_param_pack_concat(const TensorShapeArray &shapes, DType type){
memcpy(host_table->raw_ptr(), host_table_gen.data(), size * 8);
auto table = opr::Host2DeviceCopy::make(*graph, host_table);
auto z = opr::ParamPackConcat::make(srcs, table);
auto z = opr::ParamPackConcat::make(srcs, table, host_table_gen);
HostTensorND host_z;
auto func = graph->compile({make_callback_copy(z, host_z)});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册