Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
37b67c9b
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
37b67c9b
编写于
4月 10, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
5月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(dnn/parampack): reduce param pack memory use
GitOrigin-RevId: a802a14e8dbb2b291f05862bd9f0a12622d57f0c
上级
b708f15d
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
120 addition
and
109 deletion
+120
-109
dnn/src/common/param_pack.cpp
dnn/src/common/param_pack.cpp
+12
-12
dnn/src/cuda/param_pack/opr_impl.cpp
dnn/src/cuda/param_pack/opr_impl.cpp
+12
-12
dnn/src/cuda/param_pack/param_pack.cu
dnn/src/cuda/param_pack/param_pack.cu
+21
-14
dnn/src/cuda/param_pack/param_pack.cuh
dnn/src/cuda/param_pack/param_pack.cuh
+2
-3
dnn/src/naive/param_pack/opr_impl.cpp
dnn/src/naive/param_pack/opr_impl.cpp
+23
-21
src/opr/impl/tensor_manip.cpp
src/opr/impl/tensor_manip.cpp
+16
-12
src/opr/impl/tensor_manip.sereg.h
src/opr/impl/tensor_manip.sereg.h
+16
-26
src/opr/include/megbrain/opr/tensor_manip.h
src/opr/include/megbrain/opr/tensor_manip.h
+17
-8
src/opr/test/tensor_manip.cpp
src/opr/test/tensor_manip.cpp
+1
-1
未找到文件。
dnn/src/common/param_pack.cpp
浏览文件 @
37b67c9b
...
...
@@ -15,18 +15,16 @@
using
namespace
megdnn
;
void
ParamPackConcatSplitBase
::
check_exec
(
const
TensorLayout
&
concated
,
const
TensorLayout
&
table
,
const
TensorLayout
&
offsets
,
const
TensorLayout
&
parts
)
{
megdnn_assert
(
table
.
dtype
==
dtype
::
Int32
{},
"bad dtype: %s"
,
table
.
dtype
.
name
());
megdnn_assert
(
concated
.
ndim
==
1
&&
table
.
ndim
==
1
&&
parts
.
ndim
==
1
&&
concated
.
stride
[
0
]
==
1
&&
table
.
stride
[
0
]
==
1
&&
megdnn_assert
(
offsets
.
dtype
==
dtype
::
Int32
{},
"bad dtype: %s"
,
offsets
.
dtype
.
name
());
megdnn_assert
(
concated
.
ndim
==
1
&&
offsets
.
ndim
==
1
&&
parts
.
ndim
==
1
&&
concated
.
stride
[
0
]
==
1
&&
offsets
.
stride
[
0
]
==
1
&&
parts
.
stride
[
0
]
==
1
,
"bad layout: concated=%s
table
=%s parts=%s"
,
concated
.
to_string
().
c_str
(),
table
.
to_string
().
c_str
(),
"bad layout: concated=%s
offsets
=%s parts=%s"
,
concated
.
to_string
().
c_str
(),
offsets
.
to_string
().
c_str
(),
parts
.
to_string
().
c_str
());
megdnn_assert
(
table
.
shape
[
0
]
==
concated
.
shape
[
0
]
*
2
,
"concated=%zu table=%zu"
,
concated
.
shape
[
0
],
table
.
shape
[
0
]);
}
std
::
vector
<
dt_int32
>
ParamPackConcatSplitBase
::
gen_offsets
(
...
...
@@ -46,11 +44,13 @@ std::vector<dt_int32> ParamPackConcatSplitBase::gen_offsets(
return
v
+
((
alignment
-
mod
)
&
(
alignment
-
1
));
};
std
::
vector
<
dt_int32
>
offsets
(
shapes
.
size
());
std
::
vector
<
dt_int32
>
offsets
(
shapes
.
size
()
<<
1
);
size_t
offset
=
0
;
for
(
size_t
i
=
0
;
i
<
shapes
.
size
();
i
++
)
{
offsets
[
i
]
=
offset
;
offset
=
get_aligned
(
offset
)
+
shapes
[
i
].
total_nr_elems
();
offset
=
get_aligned
(
offset
);
offsets
[
i
*
2
]
=
offset
;
offset
+=
shapes
[
i
].
total_nr_elems
();
offsets
[
i
*
2
+
1
]
=
offset
;
}
return
offsets
;
}
...
...
dnn/src/cuda/param_pack/opr_impl.cpp
浏览文件 @
37b67c9b
...
...
@@ -24,7 +24,7 @@ size_t ParamPackConcatImpl::get_workspace_in_bytes(const TensorShapeArray& srcs,
template
<
typename
T
>
void
ParamPackConcatImpl
::
exec_internal
(
_megdnn_tensor_in
srcs
,
_megdnn_tensor_in
table
,
_megdnn_tensor_in
offsets
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
size_t
inp_size
=
srcs
.
layout
.
shape
[
0
],
...
...
@@ -35,25 +35,25 @@ void ParamPackConcatImpl::exec_internal(_megdnn_tensor_in srcs,
megdnn_assert_internal
(
src_cpu
);
auto
src_gpu
=
reinterpret_cast
<
const
T
**>
(
workspace
.
raw_ptr
);
auto
table_outer_gpu
=
table
.
ptr
<
int32_t
>
(),
table_inner_gpu
=
table_outer_gpu
+
out_size
;
auto
offsets_gpu
=
offsets
.
ptr
<
int32_t
>
();
cuda_check
(
cudaMemcpyAsync
(
src_gpu
,
src_cpu
,
sizeof
(
const
T
*
)
*
inp_size
,
cudaMemcpyHostToDevice
,
stream
));
param_pack
::
concat_proxy
<
T
>
(
src_gpu
,
dst
.
ptr
<
T
>
(),
out_size
,
table_outer_gpu
,
table_inner
_gpu
,
stream
);
param_pack
::
concat_proxy
<
T
>
(
src_gpu
,
dst
.
ptr
<
T
>
(),
inp_size
,
out_size
,
offsets
_gpu
,
stream
);
}
void
ParamPackConcatImpl
::
exec
(
_megdnn_tensor_in
srcs
,
_megdnn_tensor_in
table
,
void
ParamPackConcatImpl
::
exec
(
_megdnn_tensor_in
srcs
,
_megdnn_tensor_in
offsets
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
check_exec
(
dst
.
layout
,
table
.
layout
,
srcs
.
layout
);
#define cb(DType) \
if (dst.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
exec_internal<ctype>(srcs,
table
, dst, workspace); \
return; \
check_exec
(
dst
.
layout
,
offsets
.
layout
,
srcs
.
layout
);
#define cb(DType)
\
if (dst.layout.dtype == DType()) {
\
using ctype = typename DTypeTrait<DType>::ctype;
\
exec_internal<ctype>(srcs,
offsets
, dst, workspace); \
return;
\
}
MEGDNN_FOREACH_COMPUTING_DTYPE
(
cb
)
megdnn_throw
(
"bad type"
);
...
...
dnn/src/cuda/param_pack/param_pack.cu
浏览文件 @
37b67c9b
...
...
@@ -19,17 +19,24 @@ namespace param_pack {
template
<
typename
T
>
__global__
void
concat_kernel
(
const
T
**
srcs
,
T
*
dst
,
const
int32_t
*
table_outer
,
const
int32_t
*
table_inner
,
const
int32_t
*
offsets
,
size_t
srcs_size
,
size_t
total_size
)
{
size_t
addr
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
if
(
addr
<
total_size
)
{
int32_t
i
=
table_outer
[
addr
];
int32_t
idx
=
table_inner
[
addr
];
if
(
idx
!=
-
1
)
dst
[
addr
]
=
srcs
[
i
][
idx
];
else
size_t
l
=
0
,
r
=
srcs_size
-
1
,
mid
;
while
(
l
<
r
)
{
mid
=
(
l
+
r
)
>>
1
;
if
(
offsets
[(
mid
<<
1
)
+
1
]
>
addr
)
{
r
=
mid
;
}
else
{
l
=
mid
+
1
;
}
}
if
(
addr
<
offsets
[
l
<<
1
])
dst
[
addr
]
=
0
;
else
dst
[
addr
]
=
srcs
[
l
][
addr
-
offsets
[
l
<<
1
]];
}
}
...
...
@@ -59,20 +66,20 @@ void split_proxy(const T* src, T** dsts, size_t total_size,
}
template
<
typename
T
>
void
concat_proxy
(
const
T
**
srcs
,
T
*
dst
,
size_t
total_size
,
const
int32_t
*
table_outer
,
c
onst
int32_t
*
table_inner
,
c
udaStream_t
stream
)
{
void
concat_proxy
(
const
T
**
srcs
,
T
*
dst
,
size_t
srcs_size
,
size_t
total_size
,
const
int32_t
*
offsets
,
cudaStream_t
stream
)
{
size_t
NR_BLOCKS
=
DIVUP
(
total_size
,
NR_THREADS
);
concat_kernel
<<<
NR_BLOCKS
,
NR_THREADS
,
0
,
stream
>>>
(
srcs
,
dst
,
table_outer
,
table_inner
,
total_size
);
srcs
,
dst
,
offsets
,
srcs_size
,
total_size
);
after_kernel_launch
();
}
#define INST(T) \
template void concat_proxy<T>(const T**, T*, size_t, \
const int32_t*,
const int32_t*,
\
template void concat_proxy<T>(const T**, T*, size_t,
size_t,
\
const int32_t*,
\
cudaStream_t); \
template void split_proxy<T>(const T*, T**, size_t, \
template void split_proxy<T>(const T*, T**, size_t,
\
const int32_t*, const int32_t*, \
cudaStream_t);
#define cb(DType) INST(typename DTypeTrait<DType>::ctype)
...
...
dnn/src/cuda/param_pack/param_pack.cuh
浏览文件 @
37b67c9b
...
...
@@ -25,9 +25,8 @@ void split_proxy(const T* src, T** dsts, size_t total_size,
cudaStream_t
stream
);
template
<
typename
T
>
void
concat_proxy
(
const
T
**
srcs
,
T
*
dst
,
size_t
total_size
,
const
int32_t
*
table_outer
,
const
int32_t
*
table_inner
,
cudaStream_t
stream
);
void
concat_proxy
(
const
T
**
srcs
,
T
*
dst
,
size_t
srcs_size
,
size_t
total_size
,
const
int32_t
*
offsets
,
cudaStream_t
stream
);
}
// namespace param_pack
}
// namespace cuda
...
...
dnn/src/naive/param_pack/opr_impl.cpp
浏览文件 @
37b67c9b
...
...
@@ -54,38 +54,40 @@ void ParamPackSplitImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in table,
}
template
<
typename
T
>
void
ParamPackConcatImpl
::
exec_internal
(
_megdnn_tensor_in
srcs
,
int32_t
*
table
,
void
ParamPackConcatImpl
::
exec_internal
(
_megdnn_tensor_in
srcs
,
int32_t
*
offsets
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
)
{
size_t
out_size
=
dst
.
layout
.
total_nr_elems
();
auto
srcs_ptr
=
static_cast
<
const
T
**>
(
srcs
.
raw_ptr
);
auto
dst_ptr
=
dst
.
ptr
<
T
>
();
auto
table_outer
=
table
,
table_inner
=
table_outer
+
out_size
;
for
(
size_t
j
=
0
;
j
<
out_size
;
j
++
)
{
int32_t
i
=
table_outer
[
j
];
int32_t
idx
=
table_inner
[
j
];
if
(
idx
!=
-
1
)
dst_ptr
[
j
]
=
srcs_ptr
[
i
][
idx
];
else
dst_ptr
[
j
]
=
0
;
int32_t
last_pos
=
0
;
for
(
size_t
i
=
0
;
i
<
srcs
.
layout
[
0
];
i
++
)
{
int32_t
begin
=
offsets
[
i
*
2
],
end
=
offsets
[
i
*
2
+
1
];
while
(
last_pos
<
begin
)
{
dst_ptr
[
last_pos
]
=
0
;
last_pos
++
;
}
for
(
int32_t
j
=
0
;
j
<
end
-
begin
;
j
++
)
{
dst_ptr
[
begin
+
j
]
=
srcs_ptr
[
i
][
j
];
}
last_pos
=
end
;
}
}
void
ParamPackConcatImpl
::
exec
(
_megdnn_tensor_in
srcs
,
_megdnn_tensor_in
table
,
void
ParamPackConcatImpl
::
exec
(
_megdnn_tensor_in
srcs
,
_megdnn_tensor_in
offsets
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
check_exec
(
dst
.
layout
,
table
.
layout
,
srcs
.
layout
);
auto
table_ptr
=
table
.
ptr
<
int32_t
>
();
check_exec
(
dst
.
layout
,
offsets
.
layout
,
srcs
.
layout
);
auto
offsets_ptr
=
offsets
.
ptr
<
int32_t
>
();
#define cb(DType) \
if (dst.layout.dtype == DType()) { \
using ctype = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
exec_internal<ctype>(srcs,
table
_ptr, dst, workspace)); \
return; \
#define cb(DType)
\
if (dst.layout.dtype == DType()) {
\
using ctype = typename DTypeTrait<DType>::ctype;
\
MEGDNN_DISPATCH_CPU_KERN_OPR(
\
exec_internal<ctype>(srcs,
offsets
_ptr, dst, workspace)); \
return;
\
}
MEGDNN_FOREACH_COMPUTING_DTYPE
(
cb
)
megdnn_throw
(
"bad type"
);
...
...
src/opr/impl/tensor_manip.cpp
浏览文件 @
37b67c9b
...
...
@@ -1339,8 +1339,10 @@ void Concat::init_output_comp_node() {
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
ParamPackConcat
);
ParamPackConcat
::
ParamPackConcat
(
VarNodeArray
&
inp
,
VarNode
*
table
,
const
std
::
vector
<
dt_int32
>
offsets_val
,
const
OperatorNodeConfig
&
config
)
:
Super
(
inp
[
0
]
->
owner_graph
(),
config
,
"ParamPackConcat"
,
inp
)
{
:
Super
(
inp
[
0
]
->
owner_graph
(),
config
,
"ParamPackConcat"
,
inp
),
m_offsets
(
offsets_val
)
{
CompNode
cn
=
inp
[
0
]
->
comp_node
();
add_input
({
inp
[
0
]});
for
(
size_t
i
=
1
;
i
<
inp
.
size
();
i
++
)
{
...
...
@@ -1361,14 +1363,16 @@ void ParamPackConcat::add_input_layout_constraint(){
}
}
SymbolVar
ParamPackConcat
::
make
(
const
SmallVector
<
SymbolVar
>
&
inp
,
const
SymbolVar
&
table
,
const
OperatorNodeConfig
&
config
)
{
SymbolVar
ParamPackConcat
::
make
(
const
SmallVector
<
SymbolVar
>&
inp
,
const
SymbolVar
&
offsets
,
const
std
::
vector
<
dt_int32
>
offsets_val
,
const
OperatorNodeConfig
&
config
)
{
VarNodeArray
array
(
inp
.
size
());
for
(
size_t
i
=
0
;
i
<
inp
.
size
();
i
++
)
{
array
[
i
]
=
inp
[
i
].
node
();
}
return
inp
.
front
().
insert_single_output_opr
<
ParamPackConcat
>
(
array
,
table
.
node
()
,
config
);
return
inp
.
front
().
insert_single_output_opr
<
ParamPackConcat
>
(
array
,
offsets
.
node
(),
offsets_val
,
config
);
}
void
ParamPackConcat
::
scn_do_execute
()
{
...
...
@@ -1379,13 +1383,13 @@ void ParamPackConcat::scn_do_execute() {
for
(
size_t
i
=
0
;
i
<
inputs
.
size
()
-
1
;
i
++
)
{
ptr
[
i
]
=
inputs
[
i
]
->
dev_tensor
().
as_megdnn
().
raw_ptr
;
}
auto
table
=
inputs
.
back
()
->
dev_tensor
().
as_megdnn
();
auto
offsets
=
inputs
.
back
()
->
dev_tensor
().
as_megdnn
();
megdnn
::
TensorND
srcs
(
ptr
,
megdnn
::
TensorLayout
({
inputs
.
size
()
-
1
},
dtype
::
Int32
()));
auto
&&
dst
=
output
(
0
)
->
dev_tensor
().
as_megdnn
();
m_opr
->
exec
(
srcs
,
table
,
dst
,
get_megdnn_workspace_from_var
(
output
(
1
)));
m_opr
->
exec
(
srcs
,
offsets
,
dst
,
get_megdnn_workspace_from_var
(
output
(
1
)));
}
void
ParamPackConcat
::
init_output_dtype
()
{
...
...
@@ -1396,8 +1400,8 @@ void ParamPackConcat::init_output_static_infer_desc(){
using
namespace
cg
::
static_infer
;
auto
&&
mgr
=
owner_graph
()
->
static_infer_manager
();
auto
infer_out
=
[](
TensorShape
&
dest
,
const
InpVal
&
inp
)
{
dest
=
{
inp
.
val
.
back
().
shape
().
total_nr_elems
()
/
2
};
auto
infer_out
=
[
this
](
TensorShape
&
dest
,
const
InpVal
&
inp
)
{
dest
=
{
m_offsets
.
back
()
};
return
true
;
};
DepVal
shp_deps
;
...
...
@@ -1480,10 +1484,10 @@ void ParamPackSplit::init_output_dtype() {
}
void
ParamPackSplit
::
mem_plan_fwd_in2out_readonly
()
{
mgb_assert
(
m_offsets
.
size
()
==
output
().
size
());
mgb_assert
(
m_offsets
.
size
()
==
output
().
size
()
*
2
);
for
(
size_t
i
=
0
;
i
<
output
().
size
();
i
++
)
{
auto
layout
=
output
(
i
)
->
layout
();
auto
spec
=
SubTensorSpec
::
make_from_offset_elem
(
layout
,
m_offsets
[
i
]);
auto
spec
=
SubTensorSpec
::
make_from_offset_elem
(
layout
,
m_offsets
[
i
*
2
]);
m_mem_fwd_success
[
i
]
=
output
(
i
)
->
set_fwd_in2out_readonly
(
input
(
0
),
spec
);
mgb_assert
(
m_mem_fwd_success
[
i
]);
...
...
@@ -1524,7 +1528,7 @@ MGB_IMPL_OPR_GRAD(ParamPackSplit) {
}
return
ParamPackConcat
::
make
(
grad
,
opr
.
input
(
1
),
grad
,
opr
.
input
(
1
),
opr
.
get_offsets
(),
OperatorNodeConfig
{}.
follow_comp_node
(
opr
.
input
(
0
)))
.
node
();
}
...
...
src/opr/impl/tensor_manip.sereg.h
浏览文件 @
37b67c9b
...
...
@@ -31,31 +31,6 @@ namespace serialization {
struct
OprMaker
<
opr
::
GetVarShape
,
0
>:
public
OprMakerVariadic
<
opr
::
GetVarShape
>
{};
template
<
>
struct
OprLoadDumpImpl
<
opr
::
ParamPackConcat
,
0
>
{
using
ParamPackConcat
=
opr
::
ParamPackConcat
;
using
Param
=
opr
::
ParamPackConcat
::
Param
;
static
void
dump
(
OprDumpContext
&
ctx
,
const
cg
::
OperatorNodeBase
&
opr_
)
{
auto
&&
opr
=
opr_
.
cast_final_safe
<
ParamPackConcat
>
();
ctx
.
write_param
<
Param
>
(
opr
.
param
());
}
static
cg
::
OperatorNodeBase
*
load
(
OprLoadContext
&
ctx
,
const
cg
::
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
)
{
auto
param
=
ctx
.
read_param
<
Param
>
();
mgb_assert
(
!
inputs
.
empty
());
SymbolVarArray
ivar
{
inputs
.
size
()
-
1
};
for
(
size_t
i
=
0
;
i
<
inputs
.
size
()
-
1
;
++
i
)
ivar
[
i
]
=
inputs
[
i
];
return
ParamPackConcat
::
make
(
ivar
,
inputs
.
back
(),
param
,
config
).
node
()
->
owner_opr
();
}
};
template
<
>
struct
OprLoadDumpImpl
<
opr
::
Split
,
0
>
{
using
Split
=
opr
::
Split
;
...
...
@@ -151,7 +126,6 @@ namespace opr {
MGB_SEREG_OPR
(
Dimshuffle
,
1
);
MGB_SEREG_OPR
(
AxisAddRemove
,
1
);
MGB_SEREG_OPR
(
Concat
,
0
);
MGB_SEREG_OPR
(
ParamPackConcat
,
0
);
using
GetVarShapeV1
=
opr
::
GetVarShape
;
MGB_SEREG_OPR
(
GetVarShapeV1
,
0
);
using
ReshapeV1
=
opr
::
Reshape
;
...
...
@@ -193,6 +167,22 @@ namespace opr {
}
MGB_REG_OPR_SHALLOW_COPY
(
ParamPackSplit
,
opr_shallow_copy_param_pack_split
);
cg
::
OperatorNodeBase
*
opr_shallow_copy_param_pack_concat
(
const
serialization
::
OprShallowCopyContext
&
ctx
,
const
cg
::
OperatorNodeBase
&
opr_
,
const
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
){
auto
&&
opr
=
opr_
.
cast_final_safe
<
ParamPackConcat
>
();
auto
&&
offsets
=
opr
.
get_offsets
();
SymbolVarArray
ivar
{
inputs
.
size
()
-
1
};
for
(
size_t
i
=
0
;
i
<
inputs
.
size
()
-
1
;
++
i
)
ivar
[
i
]
=
inputs
[
i
];
return
ParamPackConcat
::
make
(
ivar
,
inputs
.
back
(),
offsets
,
config
).
node
()
->
owner_opr
();
}
MGB_REG_OPR_SHALLOW_COPY
(
ParamPackConcat
,
opr_shallow_copy_param_pack_concat
);
MGB_SEREG_OPR
(
RelayoutFormat
,
1
);
MGB_SEREG_OPR
(
WinogradFilterPreprocess
,
1
);
}
// namespace opr
...
...
src/opr/include/megbrain/opr/tensor_manip.h
浏览文件 @
37b67c9b
...
...
@@ -539,6 +539,7 @@ MGB_DEFINE_OPR_CLASS(Concat, cg::SingleCNOutshapePureByInshapeOprBase) // {
MGB_DEFINE_OPR_CLASS
(
ParamPackConcat
,
cg
::
SingleCNOperatorNodeBase
)
// {
//! input pointer buffer
SmallVector
<
void
*>
m_inp_ptr
;
std
::
vector
<
dt_int32
>
m_offsets
;
intl
::
UniqPtrWithCN
<
megdnn
::
ParamPackConcat
>
m_opr
;
void
add_input_layout_constraint
()
override
;
...
...
@@ -554,15 +555,23 @@ public:
return
{};
}
ParamPackConcat
(
VarNodeArray
&
inp
,
VarNode
*
table
,
const
OperatorNodeConfig
&
config
);
static
SymbolVar
make
(
const
SmallVector
<
SymbolVar
>
&
inp
,
const
SymbolVar
&
table
,
const
OperatorNodeConfig
&
config
=
{});
ParamPackConcat
(
VarNodeArray
&
inp
,
VarNode
*
offsets
,
const
std
::
vector
<
dt_int32
>
offsets_val
,
const
OperatorNodeConfig
&
config
);
static
SymbolVar
make
(
const
SmallVector
<
SymbolVar
>&
inp
,
const
SymbolVar
&
offsets
,
const
std
::
vector
<
dt_int32
>
offsets_val
,
const
OperatorNodeConfig
&
config
=
{});
static
SymbolVar
make
(
const
SmallVector
<
SymbolVar
>&
inp
,
const
SymbolVar
&
offsets
,
const
std
::
vector
<
dt_int32
>
offsets_val
,
const
Param
&
,
const
OperatorNodeConfig
&
config
)
{
return
make
(
inp
,
offsets
,
offsets_val
,
config
);
}
static
SymbolVar
make
(
const
SmallVector
<
SymbolVar
>
&
inp
,
const
SymbolVar
&
table
,
const
Param
&
,
const
OperatorNodeConfig
&
config
)
{
return
make
(
inp
,
table
,
config
);
const
std
::
vector
<
dt_int32
>&
get_offsets
()
const
{
return
m_offsets
;
}
}
;
...
...
src/opr/test/tensor_manip.cpp
浏览文件 @
37b67c9b
...
...
@@ -1906,7 +1906,7 @@ void test_param_pack_concat(const TensorShapeArray &shapes, DType type){
memcpy
(
host_table
->
raw_ptr
(),
host_table_gen
.
data
(),
size
*
8
);
auto
table
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_table
);
auto
z
=
opr
::
ParamPackConcat
::
make
(
srcs
,
table
);
auto
z
=
opr
::
ParamPackConcat
::
make
(
srcs
,
table
,
host_table_gen
);
HostTensorND
host_z
;
auto
func
=
graph
->
compile
({
make_callback_copy
(
z
,
host_z
)});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录