Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
99b17623
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
99b17623
编写于
1月 16, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix(mgb/opr): fix Reduce static value inference
GitOrigin-RevId: 5e5c56064c48eff306f7449f34e7e221510b954b
上级
a3caa5d3
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
42 addition
and
5 deletion
+42
-5
src/opr/impl/basic_arith.cpp
src/opr/impl/basic_arith.cpp
+5
-3
src/opr/include/megbrain/opr/basic_arith.h
src/opr/include/megbrain/opr/basic_arith.h
+1
-0
src/opr/test/basic_arith/reduction.cpp
src/opr/test/basic_arith/reduction.cpp
+36
-2
未找到文件。
src/opr/impl/basic_arith.cpp
浏览文件 @
99b17623
...
...
@@ -1501,8 +1501,9 @@ void Reduce::init_output_static_infer_desc() {
auto
infer_value
=
[
this
](
DeviceTensorND
&
dest
,
const
InpVal
&
inp
)
{
DeviceTensorND
workspace
;
auto
sopr
=
static_infer_opr
.
lock
();
perform
(
m_param
.
mode
,
dest
,
workspace
,
inp
.
val
[
0
].
value
(),
inp
.
val
.
at
(
1
).
shape
(),
sopr
(),
m_param
.
data_type
);
perform
(
m_param
.
mode
,
dest
,
workspace
,
inp
.
val
[
0
].
value
(),
output
(
0
)
->
dtype
(),
inp
.
val
.
at
(
1
).
shape
(),
sopr
(),
m_param
.
data_type
);
return
true
;
};
...
...
@@ -1632,6 +1633,7 @@ void Reduce::perform(
Mode
mode
,
DeviceTensorND
&
dest
,
DeviceTensorND
&
workspace
,
const
DeviceTensorND
&
input
,
const
DType
&
target_dtype
,
const
TensorShape
&
target_shape
,
intl
::
UniqPtrWithCN
<
megdnn
::
Reduce
>
&
opr
,
const
Param
::
DataType
data_type
)
{
...
...
@@ -1674,7 +1676,7 @@ void Reduce::perform(
}
opr
.
comp_node
().
activate
();
dest
.
comp_node
(
opr
.
comp_node
()).
dtype
(
input
.
dtype
()
).
resize
(
target_shape
);
dest
.
comp_node
(
opr
.
comp_node
()).
dtype
(
target_dtype
).
resize
(
target_shape
);
ksched
.
update_ptr
(
*
input_contig
,
dest
,
workspace
);
ksched
.
execute
(
opr
.
get
(),
*
input_contig
,
dest
);
}
...
...
src/opr/include/megbrain/opr/basic_arith.h
浏览文件 @
99b17623
...
...
@@ -304,6 +304,7 @@ MGB_DEFINE_OPR_CLASS(Reduce, intl::DynamicOutputIfInputDynamic<
static
void
perform
(
Mode
mode
,
DeviceTensorND
&
dest
,
DeviceTensorND
&
workspace
,
const
DeviceTensorND
&
input
,
const
DType
&
target_dtype
,
const
TensorShape
&
target_shape
,
intl
::
UniqPtrWithCN
<
megdnn
::
Reduce
>&
opr
,
const
Param
::
DataType
data_type
=
Param
::
DataType
::
DEFAULT
);
...
...
src/opr/test/basic_arith/reduction.cpp
浏览文件 @
99b17623
...
...
@@ -298,7 +298,8 @@ namespace {
static_calc_x
.
copy_from
(
*
host_x
);
opr
::
Reduce
::
perform
(
Mode
::
SUM
,
static_calc_y
,
static_calc_workspace
,
static_calc_x
,
oshp
,
static_calc_opr
);
static_calc_x
,
dtype
::
Float32
(),
oshp
,
static_calc_opr
);
host_y
.
ptr
<
float
>
()[
0
]
++
;
host_y
.
copy_from
(
static_calc_y
);
MGB_ASSERT_TENSOR_NEAR
(
expected
,
host_y
,
1e-5
);
...
...
@@ -468,7 +469,8 @@ TEST(TestBasicArithReduction, NonContPerform) {
for
(
auto
&&
tshp
:
TensorShapeArray
{{
5
,
1
},
{
1
,
5
},
{
1
,
1
},
{
1
},
{
5
,
5
}})
{
opr
::
Reduce
::
perform
(
mode
,
y
,
workspace
,
x
,
tshp
,
opr
);
opr
::
Reduce
::
perform
(
mode
,
y
,
workspace
,
x
,
dtype
::
Float32
(),
tshp
,
opr
);
ASSERT_TRUE
(
y
.
layout
().
is_contiguous
());
ASSERT_EQ
(
tshp
,
y
.
shape
());
size_t
nr
=
tshp
.
total_nr_elems
();
...
...
@@ -866,4 +868,36 @@ TEST(TestBasicArithReduction, StaticInferValue) {
MGB_ASSERT_TENSOR_EQ
(
inferred
,
expected
);
}
TEST
(
TestBasicArithReduction
,
StaticInferValueDType
)
{
using
ParamType
=
opr
::
Reduce
::
Param
::
DataType
;
DType
F32
=
dtype
::
Float32
(),
F16
=
dtype
::
Float16
();
auto
run_test
=
[](
const
DType
&
itype
,
const
DType
&
expected_otype
,
ParamType
param_dtype
)
{
HostTensorGenerator
<>
gen
;
auto
host_x
=
gen
({
2
,
3
,
4
,
5
});
auto
host_tshp
=
std
::
make_shared
<
HostTensorND
>
(
host_x
->
comp_node
(),
dtype
::
Int32
());
host_tshp
->
resize
({
1
});
host_tshp
->
ptr
<
int
>
()[
0
]
=
1
;
auto
graph
=
ComputingGraph
::
make
();
auto
x_f32
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_x
),
x
=
opr
::
TypeCvt
::
make
(
x_f32
,
itype
),
tshp
=
opr
::
Host2DeviceCopy
::
make
(
*
graph
,
host_tshp
),
y
=
opr
::
Reduce
::
make
(
x
,
{
opr
::
Reduce
::
Mode
::
SUM
,
MEGDNN_MAX_NDIM
,
param_dtype
},
tshp
);
auto
inferred
=
graph
->
static_infer_manager
().
infer_value
(
y
.
node
());
ASSERT_EQ
(
inferred
.
layout
().
dtype
,
expected_otype
);
};
run_test
(
F32
,
F32
,
ParamType
::
DEFAULT
);
run_test
(
F16
,
F16
,
ParamType
::
DEFAULT
);
run_test
(
F32
,
F32
,
ParamType
::
FLOAT_O32xC32
);
run_test
(
F16
,
F32
,
ParamType
::
FLOAT_O32xC32
);
run_test
(
F32
,
F16
,
ParamType
::
FLOAT_O16xC32
);
run_test
(
F16
,
F16
,
ParamType
::
FLOAT_O16xC32
);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录