Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b708f15d
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b708f15d
编写于
4月 08, 2020
作者:
M
Megvii Engine Team
提交者:
Xinran Xu
5月 06, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refactor(mgb/param_pack): use shared mem for param pack
GitOrigin-RevId: bc56f09037e9f7d5118df725a06d94f2c0727242
上级
f18259d7
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
68 addition
and
108 deletion
+68
-108
dnn/include/megdnn/oprs/general.h
dnn/include/megdnn/oprs/general.h
+6
-5
dnn/src/common/param_pack.cpp
dnn/src/common/param_pack.cpp
+6
-23
dnn/test/cuda/param_pack.cpp
dnn/test/cuda/param_pack.cpp
+2
-2
python_module/src/cpp/opr_defs.cpp
python_module/src/cpp/opr_defs.cpp
+5
-5
src/opr/impl/tensor_manip.cpp
src/opr/impl/tensor_manip.cpp
+29
-55
src/opr/impl/tensor_manip.sereg.h
src/opr/impl/tensor_manip.sereg.h
+2
-1
src/opr/include/megbrain/opr/tensor_manip.h
src/opr/include/megbrain/opr/tensor_manip.h
+14
-14
src/opr/test/tensor_manip.cpp
src/opr/test/tensor_manip.cpp
+4
-3
未找到文件。
dnn/include/megdnn/oprs/general.h
浏览文件 @
b708f15d
...
@@ -469,22 +469,23 @@ using Split = SplitForward;
...
@@ -469,22 +469,23 @@ using Split = SplitForward;
* large number of inputs and can handle alignment requirements. Axis is also
* large number of inputs and can handle alignment requirements. Axis is also
* not supported.
* not supported.
*
*
* The
table can be generated by gen_table
(). The \p srcs in ParamPackSplit and
* The
offsets can be generated by gen_offsets
(). The \p srcs in ParamPackSplit and
* \p dsts in ParamPackConcat must be on CPU, and must remain valid until the
* \p dsts in ParamPackConcat must be on CPU, and must remain valid until the
* execution stream is synchronized.
* execution stream is synchronized.
*/
*/
class
ParamPackConcatSplitBase
:
public
OperatorBase
{
class
ParamPackConcatSplitBase
:
public
OperatorBase
{
protected:
protected:
void
check_exec
(
const
TensorLayout
&
concated
,
const
TensorLayout
&
table
,
void
check_exec
(
const
TensorLayout
&
concated
,
const
TensorLayout
&
offsets
,
const
TensorLayout
&
parts
);
const
TensorLayout
&
parts
);
public:
public:
using
Param
=
megdnn
::
param
::
Empty
;
using
Param
=
megdnn
::
param
::
Empty
;
ParamPackConcatSplitBase
(
Handle
*
handle
)
:
OperatorBase
(
handle
)
{}
ParamPackConcatSplitBase
(
Handle
*
handle
)
:
OperatorBase
(
handle
)
{}
//! generate table to be used with ParamPackConcat and ParamPackSplit
//! generate offsets to be used with ParamPackConcat and ParamPackSplit
static
std
::
vector
<
dt_int32
>
gen_table
(
const
TensorShapeArray
&
shapes
,
static
std
::
vector
<
dt_int32
>
gen_offsets
(
const
TensorShapeArray
&
shapes
,
size_t
alignment
,
size_t
dtype_size
);
size_t
alignment
,
size_t
dtype_size
);
};
};
/**
/**
...
...
dnn/src/common/param_pack.cpp
浏览文件 @
b708f15d
...
@@ -29,7 +29,7 @@ void ParamPackConcatSplitBase::check_exec(const TensorLayout& concated,
...
@@ -29,7 +29,7 @@ void ParamPackConcatSplitBase::check_exec(const TensorLayout& concated,
"concated=%zu table=%zu"
,
concated
.
shape
[
0
],
table
.
shape
[
0
]);
"concated=%zu table=%zu"
,
concated
.
shape
[
0
],
table
.
shape
[
0
]);
}
}
std
::
vector
<
dt_int32
>
ParamPackConcatSplitBase
::
gen_
table
(
std
::
vector
<
dt_int32
>
ParamPackConcatSplitBase
::
gen_
offsets
(
const
TensorShapeArray
&
shapes
,
size_t
alignment
,
size_t
dtype_size
)
{
const
TensorShapeArray
&
shapes
,
size_t
alignment
,
size_t
dtype_size
)
{
megdnn_assert
(
alignment
&&
(
alignment
&
(
alignment
-
1
))
==
0
,
megdnn_assert
(
alignment
&&
(
alignment
&
(
alignment
-
1
))
==
0
,
"alignment must be power of 2: %zu"
,
alignment
);
"alignment must be power of 2: %zu"
,
alignment
);
...
@@ -46,30 +46,13 @@ std::vector<dt_int32> ParamPackConcatSplitBase::gen_table(
...
@@ -46,30 +46,13 @@ std::vector<dt_int32> ParamPackConcatSplitBase::gen_table(
return
v
+
((
alignment
-
mod
)
&
(
alignment
-
1
));
return
v
+
((
alignment
-
mod
)
&
(
alignment
-
1
));
};
};
std
::
vector
<
dt_int32
>
offsets
(
shapes
.
size
());
size_t
offset
=
0
;
size_t
offset
=
0
;
for
(
auto
&&
i
:
shapes
)
{
for
(
size_t
i
=
0
;
i
<
shapes
.
size
();
i
++
)
{
offset
=
get_aligned
(
offset
)
+
i
.
total_nr_elems
();
offsets
[
i
]
=
offset
;
offset
=
get_aligned
(
offset
)
+
shapes
[
i
].
total_nr_elems
();
}
}
return
offsets
;
std
::
vector
<
dt_int32
>
table
(
offset
*
2
);
auto
outer_table
=
table
.
data
(),
inner_table
=
outer_table
+
offset
;
offset
=
0
;
for
(
size_t
i
=
0
;
i
<
shapes
.
size
();
++
i
)
{
auto
aligned
=
get_aligned
(
offset
);
for
(
size_t
j
=
offset
;
j
<
aligned
;
++
j
)
{
inner_table
[
j
]
=
outer_table
[
j
]
=
-
1
;
}
offset
=
aligned
;
auto
cur_size
=
shapes
[
i
].
total_nr_elems
();
for
(
size_t
j
=
0
;
j
<
cur_size
;
++
j
)
{
outer_table
[
offset
+
j
]
=
i
;
inner_table
[
offset
+
j
]
=
j
;
}
offset
+=
cur_size
;
}
megdnn_assert
(
offset
*
2
==
table
.
size
());
return
table
;
}
}
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
dnn/test/cuda/param_pack.cpp
浏览文件 @
b708f15d
...
@@ -112,8 +112,8 @@ void test_param_pack_split(Handle* handle, const TensorShapeArray& shapes,
...
@@ -112,8 +112,8 @@ void test_param_pack_split(Handle* handle, const TensorShapeArray& shapes,
std
::
vector
<
int32_t
>
table
=
std
::
vector
<
int32_t
>
table
=
create_table
<
T
>
(
shapes
,
handle
->
alignment_requirement
());
create_table
<
T
>
(
shapes
,
handle
->
alignment_requirement
());
ASSERT_EQ
(
table
,
ASSERT_EQ
(
table
,
ParamPackSplit
::
gen_
table
(
shapes
,
handle
->
alignment_requirement
(),
ParamPackSplit
::
gen_
offsets
(
sizeof
(
T
)));
shapes
,
handle
->
alignment_requirement
(),
sizeof
(
T
)));
size_t
pack_size
=
table
.
size
()
/
2
;
size_t
pack_size
=
table
.
size
()
/
2
;
int32_t
*
table_gpu
=
create_device_data
<
int32_t
>
(
handle
,
table
.
data
(),
int32_t
*
table_gpu
=
create_device_data
<
int32_t
>
(
handle
,
table
.
data
(),
table
.
size
());
table
.
size
());
...
...
python_module/src/cpp/opr_defs.cpp
浏览文件 @
b708f15d
...
@@ -47,19 +47,19 @@ SymbolVarArray _Opr::param_pack_split(
...
@@ -47,19 +47,19 @@ SymbolVarArray _Opr::param_pack_split(
shapearr
[
i
]
=
npy
::
vec2shape
(
shapes
[
i
]);
shapearr
[
i
]
=
npy
::
vec2shape
(
shapes
[
i
]);
}
}
auto
cn
=
src
.
node
()
->
comp_node
();
auto
table_val
=
megdnn
::
ParamPackSplit
::
gen_offsets
(
shapearr
,
cn
.
get_mem_addr_alignment
(),
src
.
dtype
().
size
());
if
(
!
table
.
node
())
{
if
(
!
table
.
node
())
{
auto
cn
=
src
.
node
()
->
comp_node
();
if
(
config
.
has_comp_node_set
())
{
if
(
config
.
has_comp_node_set
())
{
cn
=
config
.
get_single_comp_node
();
cn
=
config
.
get_single_comp_node
();
}
}
auto
table_val
=
megdnn
::
ParamPackSplit
::
gen_table
(
HostTensorND
hv
{
cn
,
TensorShape
{{
table_val
.
size
()}},
dtype
::
Int32
{}};
shapearr
,
cn
.
get_mem_addr_alignment
(),
src
.
dtype
().
size
());
HostTensorND
hv
{
cn
,
TensorShape
{
table_val
.
size
()},
dtype
::
Int32
{}};
memcpy
(
hv
.
raw_ptr
(),
table_val
.
data
(),
table_val
.
size
()
*
sizeof
(
int
));
memcpy
(
hv
.
raw_ptr
(),
table_val
.
data
(),
table_val
.
size
()
*
sizeof
(
int
));
table
=
opr
::
ImmutableTensor
::
make
(
*
src
.
node
()
->
owner_graph
(),
hv
);
table
=
opr
::
ImmutableTensor
::
make
(
*
src
.
node
()
->
owner_graph
(),
hv
);
}
}
return
mgb
::
opr
::
ParamPackSplit
::
make
(
src
,
table
,
shapearr
,
config
);
return
mgb
::
opr
::
ParamPackSplit
::
make
(
src
,
table
,
table_val
,
shapearr
,
config
);
}
}
#if MGB_ENABLE_OPR_MM
#if MGB_ENABLE_OPR_MM
...
...
src/opr/impl/tensor_manip.cpp
浏览文件 @
b708f15d
...
@@ -1430,20 +1430,22 @@ void ParamPackConcat::on_output_comp_node_stream_changed(){
...
@@ -1430,20 +1430,22 @@ void ParamPackConcat::on_output_comp_node_stream_changed(){
/* f{{{ ======================= ParamPackSplit ======================= */
/* f{{{ ======================= ParamPackSplit ======================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
ParamPackSplit
);
MGB_DYN_TYPE_OBJ_FINAL_IMPL
(
ParamPackSplit
);
ParamPackSplit
::
ParamPackSplit
(
VarNode
*
src
,
VarNode
*
table
,
ParamPackSplit
::
ParamPackSplit
(
VarNode
*
src
,
VarNode
*
offsets
,
TensorShapeArray
&
shapes
,
const
OperatorNodeConfig
&
config
)
const
std
::
vector
<
dt_int32
>
offsets_val
,
:
Super
{
src
->
owner_graph
(),
config
,
"ParamPackSplit"
,
{
src
,
table
}},
TensorShapeArray
&
shapes
,
m_shapes
(
shapes
){
const
OperatorNodeConfig
&
config
)
mgb_assert
(
src
->
comp_node
()
==
table
->
comp_node
());
:
Super
{
src
->
owner_graph
(),
config
,
"ParamPackSplit"
,
{
src
,
offsets
}},
m_shapes
(
shapes
),
m_offsets
(
offsets_val
)
{
mgb_assert
(
src
->
comp_node
()
==
offsets
->
comp_node
());
add_input
({
src
});
add_input
({
src
});
add_input
({
table
});
add_input
({
offsets
});
m_mem_fwd_success
.
resize
(
m_shapes
.
size
());
for
(
size_t
i
=
0
;
i
<
shapes
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
shapes
.
size
();
i
++
)
{
mgb_assert
(
shapes
[
i
].
total_nr_elems
(),
"empty param is not allowed!"
);
mgb_assert
(
shapes
[
i
].
total_nr_elems
(),
"empty param is not allowed!"
);
add_output
(
ssprintf
(
"param_pack_o%zu"
,
i
))
->
dtype
(
src
->
dtype
());
add_output
(
ssprintf
(
"param_pack_o%zu"
,
i
))
->
dtype
(
src
->
dtype
()).
shape
(
shapes
[
i
]);
}
}
cg
::
add_workspace_output
(
this
);
}
}
void
ParamPackSplit
::
add_input_layout_constraint
(){
void
ParamPackSplit
::
add_input_layout_constraint
(){
...
@@ -1451,17 +1453,19 @@ void ParamPackSplit::add_input_layout_constraint(){
...
@@ -1451,17 +1453,19 @@ void ParamPackSplit::add_input_layout_constraint(){
}
}
SymbolVarArray
ParamPackSplit
::
make
(
const
SymbolVar
&
src
,
SymbolVarArray
ParamPackSplit
::
make
(
const
SymbolVar
&
src
,
const
SymbolVar
&
table
,
const
SymbolVar
&
offsets
,
const
std
::
vector
<
dt_int32
>
offsets_val
,
TensorShapeArray
shapes
,
TensorShapeArray
shapes
,
const
OperatorNodeConfig
&
config
)
{
const
OperatorNodeConfig
&
config
)
{
auto
&&
out
=
src
.
node
()
auto
&&
out
=
src
.
node
()
->
owner_graph
()
->
owner_graph
()
->
insert_opr
(
std
::
make_unique
<
ParamPackSplit
>
(
->
insert_opr
(
std
::
make_unique
<
ParamPackSplit
>
(
src
.
node
(),
table
.
node
(),
shapes
,
config
))
src
.
node
(),
offsets
.
node
(),
offsets_val
,
shapes
,
config
))
->
output
();
->
output
();
SymbolVarArray
ret
;
SymbolVarArray
ret
;
ret
.
resize
(
out
.
size
()
-
1
);
// do not return workspace
ret
.
resize
(
out
.
size
()
);
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ret
.
size
();
++
i
)
{
ret
[
i
]
=
out
[
i
];
ret
[
i
]
=
out
[
i
];
}
}
...
@@ -1469,41 +1473,25 @@ SymbolVarArray ParamPackSplit::make(const SymbolVar& src,
...
@@ -1469,41 +1473,25 @@ SymbolVarArray ParamPackSplit::make(const SymbolVar& src,
}
}
void
ParamPackSplit
::
scn_do_execute
()
{
void
ParamPackSplit
::
scn_do_execute
()
{
mgb_assert
(
m_opr
.
comp_node
()
==
comp_node
());
megdnn
::
TensorND
src
=
input
(
0
)
->
dev_tensor
().
as_megdnn
(),
table
=
input
(
1
)
->
dev_tensor
().
as_megdnn
();
auto
outputs
=
output
();
m_inp_ptr
.
resize
(
outputs
.
size
()
-
1
);
auto
ptr
=
m_inp_ptr
.
data
();
for
(
size_t
i
=
0
;
i
<
outputs
.
size
()
-
1
;
i
++
)
{
ptr
[
i
]
=
outputs
[
i
]
->
dev_tensor
().
as_megdnn
().
raw_ptr
;
}
megdnn
::
TensorND
dsts
(
ptr
,
megdnn
::
TensorLayout
({
outputs
.
size
()
-
1
},
dtype
::
Int32
()));
m_opr
->
exec
(
src
,
table
,
dsts
,
get_megdnn_workspace_from_var
(
outputs
.
back
()));
}
void
ParamPackSplit
::
on_output_comp_node_stream_changed
()
{
Super
::
on_output_comp_node_stream_changed
();
init_megdnn_opr
();
}
void
ParamPackSplit
::
init_megdnn_opr
(){
m_opr
=
intl
::
create_megdnn_opr
<
megdnn
::
ParamPackSplit
>
(
comp_node
());
}
}
void
ParamPackSplit
::
init_output_dtype
()
{
void
ParamPackSplit
::
init_output_dtype
()
{
// already initialized in constructor
// already initialized in constructor
}
}
void
ParamPackSplit
::
mem_plan_fwd_in2out_readonly
()
{
mgb_assert
(
m_offsets
.
size
()
==
output
().
size
());
for
(
size_t
i
=
0
;
i
<
output
().
size
();
i
++
)
{
auto
layout
=
output
(
i
)
->
layout
();
auto
spec
=
SubTensorSpec
::
make_from_offset_elem
(
layout
,
m_offsets
[
i
]);
m_mem_fwd_success
[
i
]
=
output
(
i
)
->
set_fwd_in2out_readonly
(
input
(
0
),
spec
);
mgb_assert
(
m_mem_fwd_success
[
i
]);
}
}
bool
ParamPackSplit
::
infer_shape
(
size_t
index
,
TensorShape
&
dest
,
bool
ParamPackSplit
::
infer_shape
(
size_t
index
,
TensorShape
&
dest
,
const
cg
::
static_infer
::
InpVal
&
inp
)
{
const
cg
::
static_infer
::
InpVal
&
inp
)
{
if
(
!
m_opr
.
get
()){
init_megdnn_opr
();
}
dest
=
m_shapes
[
index
];
dest
=
m_shapes
[
index
];
return
true
;
return
true
;
}
}
...
@@ -1515,33 +1503,19 @@ void ParamPackSplit::init_output_static_infer_desc() {
...
@@ -1515,33 +1503,19 @@ void ParamPackSplit::init_output_static_infer_desc() {
DepVal
shp_deps
{{
input
(
0
),
DepType
::
SHAPE
},
{
input
(
1
),
DepType
::
SHAPE
}};
DepVal
shp_deps
{{
input
(
0
),
DepType
::
SHAPE
},
{
input
(
1
),
DepType
::
SHAPE
}};
auto
infer_wk
=
[
this
](
TensorShape
&
dst
,
const
InpVal
&
inp
){
for
(
size_t
i
=
0
;
i
<
output
().
size
();
i
++
)
{
dst
.
ndim
=
1
;
if
(
!
m_opr
.
get
()){
init_megdnn_opr
();
}
dst
.
shape
[
0
]
=
m_opr
->
get_workspace_in_bytes
(
inp
.
val
.
at
(
0
).
shape
(),
inp
.
val
.
at
(
1
).
shape
(),
m_shapes
);
return
true
;
};
for
(
size_t
i
=
0
;
i
<
output
().
size
()
-
1
;
i
++
)
{
auto
ov
=
output
(
i
);
auto
ov
=
output
(
i
);
mgr
.
register_shape_infer
(
mgr
.
register_shape_infer
(
ov
,
{
SourceType
::
DEP
,
shp_deps
,
ov
,
{
SourceType
::
DEP
,
shp_deps
,
std
::
bind
(
&
ParamPackSplit
::
infer_shape
,
this
,
i
,
_1
,
_2
)});
std
::
bind
(
&
ParamPackSplit
::
infer_shape
,
this
,
i
,
_1
,
_2
)});
}
}
mgr
.
register_shape_infer
(
output
().
back
(),
{
SourceType
::
DEP
,
shp_deps
,
infer_wk
});
}
}
MGB_IMPL_OPR_GRAD
(
ParamPackSplit
)
{
MGB_IMPL_OPR_GRAD
(
ParamPackSplit
)
{
mgb_assert
(
out_grad
.
size
()
==
opr
.
output
().
size
());
mgb_assert
(
out_grad
.
size
()
==
opr
.
output
().
size
());
SmallVector
<
SymbolVar
>
grad
;
SmallVector
<
SymbolVar
>
grad
;
// last var is workspace, ignore it
// last var is workspace, ignore it
for
(
size_t
i
=
0
;
i
<
out_grad
.
size
()
-
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
out_grad
.
size
();
++
i
)
{
auto
gval
=
out_grad
[
i
];
auto
gval
=
out_grad
[
i
];
if
(
!
gval
)
{
if
(
!
gval
)
{
gval
=
SymbolVar
{
opr
.
output
(
i
)}.
fill_retain_dtype
(
0
).
node
();
gval
=
SymbolVar
{
opr
.
output
(
i
)}.
fill_retain_dtype
(
0
).
node
();
...
...
src/opr/impl/tensor_manip.sereg.h
浏览文件 @
b708f15d
...
@@ -185,9 +185,10 @@ namespace opr {
...
@@ -185,9 +185,10 @@ namespace opr {
const
cg
::
OperatorNodeBase
&
opr_
,
const
VarNodeArray
&
inputs
,
const
cg
::
OperatorNodeBase
&
opr_
,
const
VarNodeArray
&
inputs
,
const
OperatorNodeConfig
&
config
){
const
OperatorNodeConfig
&
config
){
auto
&&
opr
=
opr_
.
cast_final_safe
<
ParamPackSplit
>
();
auto
&&
opr
=
opr_
.
cast_final_safe
<
ParamPackSplit
>
();
auto
&&
offsets
=
opr
.
get_offsets
();
auto
&&
shape
=
opr
.
get_output_shapes
();
auto
&&
shape
=
opr
.
get_output_shapes
();
return
ParamPackSplit
::
make
(
inputs
[
0
],
inputs
[
1
],
shape
,
config
).
at
(
0
).
return
ParamPackSplit
::
make
(
inputs
[
0
],
inputs
[
1
],
offsets
,
shape
,
config
).
at
(
0
).
node
()
->
owner_opr
();
node
()
->
owner_opr
();
}
}
...
...
src/opr/include/megbrain/opr/tensor_manip.h
浏览文件 @
b708f15d
...
@@ -570,31 +570,31 @@ public:
...
@@ -570,31 +570,31 @@ public:
* \brief Opr used to split parameter
* \brief Opr used to split parameter
*/
*/
MGB_DEFINE_OPR_CLASS
(
ParamPackSplit
,
cg
::
SingleCNOperatorNodeBase
)
// {
MGB_DEFINE_OPR_CLASS
(
ParamPackSplit
,
cg
::
SingleCNOperatorNodeBase
)
// {
//! input pointer buffer
SmallVector
<
void
*>
m_inp_ptr
;
intl
::
UniqPtrWithCN
<
megdnn
::
ParamPackSplit
>
m_opr
;
TensorShapeArray
m_shapes
;
TensorShapeArray
m_shapes
;
std
::
vector
<
dt_int32
>
m_offsets
;
std
::
vector
<
bool
>
m_mem_fwd_success
;
void
scn_do_execute
()
override
;
void
scn_do_execute
()
override
;
void
init_output_static_infer_desc
()
override
;
void
init_output_static_infer_desc
()
override
;
void
on_output_comp_node_stream_changed
()
override
;
bool
infer_shape
(
size_t
index
,
TensorShape
&
dest
,
bool
infer_shape
(
size_t
index
,
TensorShape
&
dest
,
const
cg
::
static_infer
::
InpVal
&
inp
);
const
cg
::
static_infer
::
InpVal
&
inp
);
void
init_output_dtype
()
override
;
void
init_output_dtype
()
override
;
void
mem_plan_fwd_in2out_readonly
()
override
;
void
add_input_layout_constraint
()
override
;
void
add_input_layout_constraint
()
override
;
void
init_megdnn_opr
();
public
:
public
:
ParamPackSplit
(
VarNode
*
src
,
VarNode
*
table
,
TensorShapeArray
&
shapes
,
ParamPackSplit
(
VarNode
*
src
,
VarNode
*
offsets
,
const
OperatorNodeConfig
&
config
);
const
std
::
vector
<
dt_int32
>
offsets_val
,
TensorShapeArray
&
shapes
,
const
OperatorNodeConfig
&
config
);
static
SymbolVarArray
make
(
const
SymbolVar
&
src
,
const
SymbolVar
&
offsets
,
const
std
::
vector
<
dt_int32
>
offsets_val
,
TensorShapeArray
shapes
,
const
OperatorNodeConfig
&
config
=
{});
static
SymbolVarArray
make
(
const
SymbolVar
&
src
,
const
SymbolVar
&
table
,
const
std
::
vector
<
dt_int32
>&
get_offsets
()
const
{
TensorShapeArray
shapes
,
const
OperatorNodeConfig
&
config
=
{});
return
m_offsets
;
}
const
TensorShapeArray
&
get_output_shapes
()
const
{
const
TensorShapeArray
&
get_output_shapes
()
const
{
return
m_shapes
;
return
m_shapes
;
...
...
src/opr/test/tensor_manip.cpp
浏览文件 @
b708f15d
...
@@ -1898,7 +1898,7 @@ void test_param_pack_concat(const TensorShapeArray &shapes, DType type){
...
@@ -1898,7 +1898,7 @@ void test_param_pack_concat(const TensorShapeArray &shapes, DType type){
srcs
.
push_back
(
nd
);
srcs
.
push_back
(
nd
);
}
}
auto
host_table_gen
=
megdnn
::
ParamPackSplit
::
gen_
table
(
shapes
,
auto
host_table_gen
=
megdnn
::
ParamPackSplit
::
gen_
offsets
(
shapes
,
cn
.
get_mem_addr_alignment
(),
4
);
cn
.
get_mem_addr_alignment
(),
4
);
ASSERT_EQ
(
host_table_gen
.
size
(),
size
*
2
);
ASSERT_EQ
(
host_table_gen
.
size
(),
size
*
2
);
auto
host_table
=
std
::
make_shared
<
HostTensorND
>
();
auto
host_table
=
std
::
make_shared
<
HostTensorND
>
();
...
@@ -1944,7 +1944,7 @@ void test_param_pack_split(const TensorShapeArray& shapes) {
...
@@ -1944,7 +1944,7 @@ void test_param_pack_split(const TensorShapeArray& shapes) {
auto
make_graph
=
[
&
](
const
typename
Checker
::
SymInpArray
&
inputs
)
->
auto
make_graph
=
[
&
](
const
typename
Checker
::
SymInpArray
&
inputs
)
->
typename
Checker
::
SymOutArray
{
typename
Checker
::
SymOutArray
{
auto
table_val
=
megdnn
::
ParamPackSplit
::
gen_
table
(
auto
table_val
=
megdnn
::
ParamPackSplit
::
gen_
offsets
(
shapes
,
cn
.
get_mem_addr_alignment
(),
4
);
shapes
,
cn
.
get_mem_addr_alignment
(),
4
);
HostTensorND
table
;
HostTensorND
table
;
std
::
copy_n
(
table_val
.
data
(),
table_val
.
size
(),
std
::
copy_n
(
table_val
.
data
(),
table_val
.
size
(),
...
@@ -1954,7 +1954,8 @@ void test_param_pack_split(const TensorShapeArray& shapes) {
...
@@ -1954,7 +1954,8 @@ void test_param_pack_split(const TensorShapeArray& shapes) {
.
ptr
<
dt_int32
>
());
.
ptr
<
dt_int32
>
());
auto
sym_table
=
opr
::
SharedDeviceTensor
::
make
(
auto
sym_table
=
opr
::
SharedDeviceTensor
::
make
(
*
inputs
[
0
].
node
()
->
owner_graph
(),
table
);
*
inputs
[
0
].
node
()
->
owner_graph
(),
table
);
auto
out
=
opr
::
ParamPackSplit
::
make
(
inputs
[
0
],
sym_table
,
shapes
);
auto
out
=
opr
::
ParamPackSplit
::
make
(
inputs
[
0
],
sym_table
,
table_val
,
shapes
);
mgb_assert
(
out
.
size
()
==
nr_out
);
mgb_assert
(
out
.
size
()
==
nr_out
);
typename
Checker
::
SymOutArray
ret
;
typename
Checker
::
SymOutArray
ret
;
for
(
size_t
i
=
0
;
i
<
nr_out
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
nr_out
;
++
i
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录