MegEngine commit b2827cb1
Authored on Sep 14, 2021 by Megvii Engine Team
feat(opr): let Dot, MatrixMul and BatchedMatrixMul support empty input
GitOrigin-RevId: 10a3c5b106d4013f486d8b99593959d90c760885
Parent: 50f73877
Showing 6 changed files with 99 additions and 5 deletions (+99 −5)
imperative/python/test/unit/functional/test_functional.py   +20 −0
src/opr/impl/blas.cpp                                        +61 −0
src/opr/include/megbrain/opr/blas.h                          +7  −0
src/opr/test/blas.cpp                                        +10 −3
test/src/autocheck.cpp                                       +0  −1
test/src/numerical_diff.cpp                                  +1  −1
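In short, the three BLAS operators now accept inputs with a zero-length axis: the output shape is still inferable, and when the output is non-empty it is filled with zeros, since a sum over an empty contraction axis is zero. A minimal sketch of the resulting behavior, assuming a MegEngine build that includes this commit (the import paths are the usual MegEngine ones, not taken from this page):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    a = tensor(np.random.randn(10, 0))   # contraction axis K has length 0
    b = tensor(np.random.randn(0, 10))
    out = F.matmul(a, b)                 # previously an error; now a (10, 10) result
    assert out.numpy().shape == (10, 10)
    assert np.all(out.numpy() == 0)      # an empty reduction sums to zero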
imperative/python/test/unit/functional/test_functional.py
@@ -142,6 +142,26 @@ def test_matmul():
     )

+@pytest.mark.parametrize(
+    "shape_a, shape_b",
+    [((0,), (0,)), ((10, 0), (0, 10)), ((3, 10, 0), (3, 0, 10)),],
+)
+@pytest.mark.parametrize("is_symbolic", [None, True, False])
+def test_matmul_empty_tensor(shape_a, shape_b, is_symbolic):
+    def func(a, b):
+        return F.matmul(a, b)
+
+    if is_symbolic is not None:
+        func = jit.trace(symbolic=is_symbolic)(func)
+
+    a = tensor(np.random.randn(*shape_a))
+    b = tensor(np.random.randn(*shape_b))
+    for _ in range(3):
+        out = func(a, b)
+        assert np.all(out.numpy() == 0)
+        if is_symbolic is None:
+            break
+

 def test_interpolate():
     def linear_interpolate():
         inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
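The three parametrized shape pairs line up with the three operators named in the commit title: a dot product of empty vectors, a 2-D matrix product with K == 0, and its batched variant. NumPy, shown here purely as a reference oracle (the test itself only checks the outputs against zero), agrees on the expected results:

    import numpy as np

    np.dot(np.empty(0), np.empty(0))                       # Dot -> 0.0
    np.matmul(np.empty((10, 0)), np.empty((0, 10)))        # MatrixMul -> (10, 10) zeros
    np.matmul(np.empty((3, 10, 0)), np.empty((3, 0, 10)))  # BatchedMatrixMul -> (3, 10, 10) zeros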
src/opr/impl/blas.cpp
@@ -45,6 +45,7 @@ MatrixMul::MatrixMul(VarNode* a, VarNode* b, const Param& param,
     init_megdnn_opr(*this, param);
     m_policy = policy;
     add_input({a, b});
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 }

 SymbolVar MatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
@@ -61,6 +62,15 @@ void MatrixMul::init_output_dtype() {
     output(0)->dtype(output_dtype);
 }

+MatrixMul::NodeProp* MatrixMul::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 bool MatrixMul::check_layout(const TensorLayout& layout, int transpose) {
     mgb_assert(layout.ndim == 2, "input to MatrixMul must be 2-dim; got %s",
                layout.to_string().c_str());
@@ -138,6 +148,17 @@ void MatrixMul::scn_do_execute() {
     auto inp0 = input(0)->dev_tensor().as_megdnn(),
          inp1 = input(1)->dev_tensor().as_megdnn(),
          out = output(0)->dev_tensor().as_megdnn();
+    if ((inp0.layout.is_empty() || inp1.layout.is_empty())) {
+        if (!out.layout.is_empty()) {
+            if (!m_fill_opr) {
+                m_fill_opr = intl::get_megdnn_handle(comp_node())
+                                     ->create_operator<megdnn::Fill>();
+            }
+            m_fill_opr->param() = 0;
+            m_fill_opr->exec(out, {});
+        }
+        return;
+    }
     auto transpose = [](TensorLayout& layout, bool& trans) {
         if (!check_layout(layout, 0)) {
             mgb_assert(check_layout(layout, 1));
@@ -193,6 +214,7 @@ BatchedMatrixMul::BatchedMatrixMul(VarNode* a, VarNode* b, const Param& param,
     init_megdnn_opr(*this, param);
     m_policy = policy;
     add_input({a, b});
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 }

 SymbolVar BatchedMatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
@@ -229,6 +251,15 @@ void BatchedMatrixMul::init_output_dtype() {
     output(0)->dtype(output_dtype);
 }

+BatchedMatrixMul::NodeProp* BatchedMatrixMul::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 bool BatchedMatrixMul::check_layout(const TensorLayout& layout,
                                     bool transpose) {
     int lhs = (transpose) ? 2 : 1, rhs = (transpose) ? 1 : 2;
@@ -294,6 +325,17 @@ void BatchedMatrixMul::scn_do_execute() {
     auto inp0 = input(0)->dev_tensor().as_megdnn(),
          inp1 = input(1)->dev_tensor().as_megdnn(),
          out = output(0)->dev_tensor().as_megdnn();
+    if ((inp0.layout.is_empty() || inp1.layout.is_empty())) {
+        if (!out.layout.is_empty()) {
+            if (!m_fill_opr) {
+                m_fill_opr = intl::get_megdnn_handle(comp_node())
+                                     ->create_operator<megdnn::Fill>();
+            }
+            m_fill_opr->param() = 0;
+            m_fill_opr->exec(out, {});
+        }
+        return;
+    }
     auto transpose = [](TensorLayout& layout, bool& trans) {
         if (!check_layout(layout, false)) {
             mgb_assert(check_layout(layout, true));
@@ -354,6 +396,7 @@ Dot::Dot(VarNode *opr0, VarNode *opr1, const OperatorNodeConfig &config):
 {
     init_megdnn_opr(*this, {});
     add_input({opr0, opr1}, AddInputSortType::CUR_ADDED);
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     static_assert(std::is_empty<Param>::value, "Dot param should be empty");
     mgb_assert(opr0->dtype().category() != DTypeCategory::QUANTIZED &&
                opr1->dtype().category() != DTypeCategory::QUANTIZED,
@@ -406,10 +449,28 @@ void Dot::scn_do_execute() {
             i1.layout.stride[0] = 0;
         }
     }
+    if ((i0.layout.is_empty() || i1.layout.is_empty())) {
+        if (!m_fill_opr) {
+            m_fill_opr = intl::get_megdnn_handle(comp_node())
+                                 ->create_operator<megdnn::Fill>();
+        }
+        m_fill_opr->param() = 0;
+        m_fill_opr->exec(output(0)->dev_tensor().as_megdnn(), {});
+        return;
+    }
     megdnn_opr()->exec(i0, i1, output(0)->dev_tensor().as_megdnn(),
                        intl::get_megdnn_workspace_from_var(output(1)));
 }

+Dot::NodeProp* Dot::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 void Dot::add_input_layout_constraint() {
     auto check = [](const TensorLayout& ly) {
         mgb_throw_if(ly.ndim != 1, GraphError,
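All three operators follow the same execution-time rule: if either input is empty and the output is not, lazily create a cached megdnn::Fill operator and write zeros; if the output is itself empty, return without touching it; otherwise dispatch to the normal kernel. A pure-NumPy sketch of that rule for the 2-D case (function and variable names here are illustrative, not MegEngine API):

    import numpy as np

    def matmul_allowing_empty(a, b):
        # Mirrors the structure of MatrixMul::scn_do_execute above.
        (m, k), (k2, n) = a.shape, b.shape
        assert k == k2, "contraction axes must match"
        if a.size == 0 or b.size == 0:
            if m > 0 and n > 0:
                # Non-empty output from empty inputs: this is the megdnn::Fill
                # path, which writes zeros (a sum of zero products).
                return np.zeros((m, n), dtype=a.dtype)
            # The output is itself empty; nothing to compute or fill.
            return np.empty((m, n), dtype=a.dtype)
        return a @ b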
src/opr/include/megbrain/opr/blas.h
@@ -17,6 +17,7 @@
 #include "megbrain/graph.h"
 #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+#include "megdnn/oprs/general.h"
 #include "megdnn/oprs/linalg.h"

 namespace mgb {
@@ -40,6 +41,7 @@ private:
     void add_input_layout_constraint() override;
     void scn_do_execute() override;
     void init_output_dtype() override;
+    NodeProp* do_make_node_prop() const override;
     size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
                                     const TensorShapeArray& output_shapes)
             const override;
@@ -47,6 +49,7 @@ private:
     //! store the policy of all transpose situations
     megdnn::ExecutionPolicy m_cadidate_execution_policies[4];
+    std::unique_ptr<megdnn::Fill> m_fill_opr;
 };

 /*!
@@ -70,6 +73,7 @@ private:
     void add_input_layout_constraint() override;
     void init_output_dtype() override;
     void scn_do_execute() override;
+    NodeProp* do_make_node_prop() const override;
     size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
                                     const TensorShapeArray& output_shapes)
             const override;
@@ -77,6 +81,7 @@ private:
     static bool check_layout(const TensorLayout& layout, bool transpose);
     //! store the policy of all transpose situations
     megdnn::ExecutionPolicy m_cadidate_execution_policies[4];
+    std::unique_ptr<megdnn::Fill> m_fill_opr;
 };

 /*!
@@ -101,7 +106,9 @@ MGB_DEFINE_OPR_CLASS(Dot, cg::SingleCNOperatorNodeBaseT<
     void add_input_layout_constraint() override;
     void scn_do_execute() override;
     void init_output_static_infer_desc() override;
+    NodeProp* do_make_node_prop() const override;
     void record_execute_deps(ExecDependencyArray& deps) override;
+    std::unique_ptr<megdnn::Fill> m_fill_opr;
 };

 MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(MatrixInverse);
src/opr/test/blas.cpp
@@ -94,7 +94,9 @@ void run_sgemm_test(bool transa, bool transb) {
     Checker(make_graph, fwd)
             .run({mkx(4, 6), mky(6, 2)}, opt)
             .run({mkx(2, 3), mky(3, 100)}, opt)
-            .run({mkx(20, 3), mky(3, 20)}, opt);
+            .run({mkx(20, 3), mky(3, 20)}, opt)
+            .run({mkx(10, 0), mky(0, 10)}, opt)
+            .run({mkx(0, 0), mky(0, 0)}, opt);
 }

 #define FWD_BATCH_GEMM(dt_src, dt_dst) \
@@ -143,7 +145,9 @@ void run_batched_sgemm_test(bool transa, bool transb) {
     Checker(make_graph, fwd)
             .run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
             .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
-            .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt);
+            .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt)
+            .run({mkx(3, 0, 2), mky(3, 2, 0)}, opt)
+            .run({mkx(64, 10, 0), mky(64, 0, 10)}, opt);
 }

 auto gen_fp16 = [](HostTensorND& dest) {
@@ -198,6 +202,7 @@ void run_batched_hgemm_test(bool transa, bool transb) {
     checker.run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
             .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
+            .run({mkx(64, 10, 0), mky(64, 0, 10)}, opt)
             .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt);
 }
@@ -236,6 +241,7 @@ void run_batched_igemm_test(bool transa, bool transb) {
     checker.run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
             .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
+            .run({mkx(64, 10, 0), mky(64, 0, 10)}, opt)
             .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt);
 }
@@ -650,7 +656,8 @@ TEST(TestOprBlas, Dot) {
             .run({TensorShape{15}, TensorShape{1}})
             .run({TensorShape{1}, TensorShape{16}})
             .run({TensorShape{23}, TensorShape{23}})
-            .run({TensorShape{1000}, TensorShape{1000}});
+            .run({TensorShape{1000}, TensorShape{1000}})
+            .run({TensorShape{0}, TensorShape{0}});
 }

 TEST(TestOprBlas, TransMatMul) {
test/src/autocheck.cpp
@@ -250,7 +250,6 @@ DEF_IMPL(void)::do_run(const ShapeInpArray& shapes, const RunOptions& opt) {
     for (size_t i = 0; i < nr_out; ++i) {
         if (m_outputs_allow_grad[i]) {
             auto nr = m_outputs_truth[i].shape().total_nr_elems();
-            mgb_assert(nr, "got empty output");
             if (opt.cont_loss_p) {
                 m_loss_p[i]->resize({nr});
                 auto ptr = m_loss_p[i]->template ptr<float>();
test/src/numerical_diff.cpp
@@ -36,7 +36,7 @@ std::vector<HostTensorND> mgb::numerical_diff_pt2(
         resize(cur_inp->shape());
         auto dptr = dest.ptr<float>();
-        mgb_assert(cur_inp->layout().is_contiguous());
+        mgb_assert(cur_inp->layout().is_contiguous() || cur_inp->layout().is_empty());
         auto cur_inp_ptr = cur_inp->ptr<float>();
         mgb::RealTimer timer;
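These last two changes relax the test harness rather than the operators: autocheck no longer rejects empty outputs, and numerical_diff accepts empty layouts, because perturbation-based gradient checking over an empty tensor simply iterates zero elements. A simplified Python stand-in for mgb::numerical_diff_pt2 (not its real signature; assumes x is contiguous so ravel() returns a view):

    import numpy as np

    def central_difference_grad(loss, x, eps=1e-3):
        grad = np.empty_like(x)
        flat_x, flat_g = x.ravel(), grad.ravel()
        for i in range(flat_x.size):       # zero iterations when x is empty
            orig = flat_x[i]
            flat_x[i] = orig + eps
            f_pos = loss()                 # loss() reads x through the view
            flat_x[i] = orig - eps
            f_neg = loss()
            flat_x[i] = orig
            flat_g[i] = (f_pos - f_neg) / (2 * eps)
        return grad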