Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3d3666b6
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
407
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
3d3666b6
编写于
3年前
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test(dnn/bn): add compatible configs for NHWC BN
GitOrigin-RevId: ac757ca307f53ee2af9af8e3943a1d7776fa6c37
上级
b3e54ead
master
HuaHua404-patch-1
HuaHua404-patch-2
HuaHua404-patch-3
HuaHua404-patch-4
add-tools
dev-support-lite-fork-debug-mode
docstring-reshape
release-1.10
release-1.11
release-1.11.1
release-1.12.0
release-1.12.1
release-1.12.2
release-1.12.3
release-1.12.4
release-1.12.5
release-1.13.0
release-1.13.1
release-1.7
release-1.8
release-1.9
revert-410-docstring-zeros
revert-411-add-tools
test-try-import
tmp-test
try-import
v1.13.1
v1.13.0
v1.12.4
v1.12.3
v1.12.2
v1.12.1
v1.12.0
v1.11.1
v1.11.0
v1.10.0
v1.9.1
v1.9.0
v1.8.2
v1.8.1
v1.8.1.m1
v1.8.0
v1.7.2.m1
v1.7.1.m1
v1.7.0
v1.7.0.m1
无相关合并请求
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
89 addition
and
26 deletion
+89
-26
dnn/test/common/bn.h
dnn/test/common/bn.h
+13
-1
dnn/test/common/checker.cpp
dnn/test/common/checker.cpp
+1
-1
dnn/test/common/checker.h
dnn/test/common/checker.h
+5
-0
dnn/test/common/deduce_layout_proxy.h
dnn/test/common/deduce_layout_proxy.h
+9
-0
dnn/test/common/exec_proxy.h
dnn/test/common/exec_proxy.h
+17
-0
dnn/test/common/rng.cpp
dnn/test/common/rng.cpp
+4
-0
dnn/test/cuda/bn.cpp
dnn/test/cuda/bn.cpp
+33
-20
dnn/test/rocm/bn.cpp
dnn/test/rocm/bn.cpp
+7
-4
未找到文件。
dnn/test/common/bn.h
浏览文件 @
3d3666b6
...
...
@@ -53,6 +53,18 @@ std::vector<TestArg> get_args() {
TensorShape
{
1
,
3
,
1
,
1
},
dtype
::
Float16
());
}
// case 3: 1 x 1 x 1 x C
for
(
size_t
i
=
4
;
i
<
257
;
i
*=
4
)
{
param
::
BN
param
;
param
.
fwd_mode
=
param
::
BN
::
FwdMode
::
TRAINING
;
param
.
param_dim
=
param
::
BN
::
ParamDim
::
DIM_111C
;
args
.
emplace_back
(
param
,
TensorShape
{
3
,
i
,
i
,
3
},
TensorShape
{
1
,
1
,
1
,
3
},
dtype
::
Float32
());
args
.
emplace_back
(
param
,
TensorShape
{
3
,
i
,
i
,
3
},
TensorShape
{
1
,
1
,
1
,
3
},
dtype
::
Float16
());
}
return
args
;
}
...
...
@@ -60,4 +72,4 @@ std::vector<TestArg> get_args() {
}
// namespace test
}
// namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
// vim: syntax=cpp.doxygen
This diff is collapsed.
Click to expand it.
dnn/test/common/checker.cpp
浏览文件 @
3d3666b6
...
...
@@ -419,7 +419,7 @@ void CheckerHelper::copy_tensors_from_device(const TensorValueArray& dest,
void
CheckerHelper
::
check_tensors
(
const
TensorValueArray
&
expected
,
const
TensorValueArray
&
computed
)
{
for
(
size_t
i
=
0
;
i
<
expected
.
size
();
++
i
)
{
if
(
expected
[
i
].
layout
.
ndim
==
0
)
if
(
expected
[
i
].
layout
.
ndim
==
0
||
m_bypass
.
find
(
i
)
!=
m_bypass
.
end
()
)
continue
;
if
(
m_allow_invalid_check
)
{
MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID
(
...
...
This diff is collapsed.
Click to expand it.
dnn/test/common/checker.h
浏览文件 @
3d3666b6
...
...
@@ -69,6 +69,7 @@ protected:
std
::
unordered_map
<
size_t
,
RNG
*>
m_rng
;
std
::
unordered_map
<
size_t
,
DType
>
m_dtype
;
std
::
unordered_map
<
size_t
,
TensorFormat
>
m_fmt
;
std
::
set
<
size_t
>
m_bypass
;
float_t
m_epsilon
=
1e-3
,
m_max_avg_error
=
1e-3
,
m_max_avg_biased_error
=
1e-3
;
float_t
m_perf_check_threshold
=
-
1
;
...
...
@@ -184,6 +185,10 @@ public:
m_rng
[
idx
]
=
rng
;
return
*
this
;
}
Checker
&
set_bypass
(
size_t
idx
)
{
m_bypass
.
insert
(
idx
);
return
*
this
;
}
//! max error of a single element
Checker
&
set_epsilon
(
dt_float32
epsilon
)
{
m_epsilon
=
epsilon
;
...
...
This diff is collapsed.
Click to expand it.
dnn/test/common/deduce_layout_proxy.h
浏览文件 @
3d3666b6
...
...
@@ -82,6 +82,15 @@ struct DeduceLayoutProxy<Opr, 8, true> {
}
};
template
<
typename
Opr
>
struct
DeduceLayoutProxy
<
Opr
,
9
,
true
>
{
static
void
deduce_layout
(
Opr
*
opr
,
TensorLayoutArray
&
layouts
)
{
megdnn_assert
(
layouts
.
size
()
==
9
);
opr
->
deduce_layout
(
layouts
[
0
],
layouts
[
1
],
layouts
[
2
],
layouts
[
3
],
layouts
[
4
],
layouts
[
5
],
layouts
[
6
],
layouts
[
7
],
layouts
[
8
]);
}
};
}
// namespace test
}
// namespace megdnn
...
...
This diff is collapsed.
Click to expand it.
dnn/test/common/exec_proxy.h
浏览文件 @
3d3666b6
...
...
@@ -22,6 +22,23 @@ namespace test {
template
<
typename
Opr
,
size_t
Arity
,
bool
has_workspace
>
struct
ExecProxy
;
template
<
typename
Opr
>
struct
ExecProxy
<
Opr
,
9
,
true
>
{
WorkspaceWrapper
W
;
void
exec
(
Opr
*
opr
,
const
TensorNDArray
&
tensors
)
{
if
(
!
W
.
valid
())
{
W
=
WorkspaceWrapper
(
opr
->
handle
(),
0
);
}
W
.
update
(
opr
->
get_workspace_in_bytes
(
tensors
[
0
].
layout
,
tensors
[
1
].
layout
,
tensors
[
2
].
layout
,
tensors
[
3
].
layout
,
tensors
[
4
].
layout
,
tensors
[
5
].
layout
,
tensors
[
6
].
layout
,
tensors
[
7
].
layout
,
tensors
[
8
].
layout
));
opr
->
exec
(
tensors
[
0
],
tensors
[
1
],
tensors
[
2
],
tensors
[
3
],
tensors
[
4
],
tensors
[
5
],
tensors
[
6
],
tensors
[
7
],
tensors
[
8
],
W
.
workspace
());
}
};
template
<
typename
Opr
>
struct
ExecProxy
<
Opr
,
8
,
true
>
{
WorkspaceWrapper
W
;
...
...
This diff is collapsed.
Click to expand it.
dnn/test/common/rng.cpp
浏览文件 @
3d3666b6
...
...
@@ -211,6 +211,10 @@ void IIDRNG::gen(const TensorND& tensor) {
}
return
;
}
if
(
tensor
.
layout
.
dtype
.
enumv
()
==
DTypeEnum
::
Byte
)
{
memset
(
tensor
.
raw_ptr
,
0
,
tensor
.
layout
.
access_bytes
());
return
;
}
megdnn_assert
(
0
,
"IIDRNG does not know how to generate value for DType %s"
,
tensor
.
layout
.
dtype
.
name
());
}
...
...
This diff is collapsed.
Click to expand it.
dnn/test/cuda/bn.cpp
浏览文件 @
3d3666b6
...
...
@@ -6,10 +6,13 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/cuda/fixture.h"
#include "src/cuda/batch_normalization/opr_impl.h"
#include "src/cuda/utils.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "test/common/bn.h"
...
...
@@ -21,15 +24,26 @@
namespace
megdnn
{
namespace
test
{
TEST_F
(
CUDA
,
BN_FORWARD
)
{
TEST_F
(
CUDA
,
BN_FORWARD
_BACKWARD
)
{
using
namespace
batch_normalization
;
using
cuda
::
cudnn_handle
;
using
cuda
::
batch_normalization
::
BNTensorDescHolder
;
using
cuda
::
batch_normalization
::
get_reserve_size
;
std
::
vector
<
TestArg
>
args
=
get_args
();
Checker
<
BNForward
>
checker
(
handle_cuda
());
Checker
<
BNBackward
>
checker_bwd
(
handle_cuda
());
for
(
auto
&&
arg
:
args
)
{
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
auto
tensor_desc
=
BNTensorDescHolder
({
arg
.
src
,
arg
.
dtype
},
arg
.
param
.
param_dim
,
arg
.
param
.
fwd_mode
);
auto
reserve
=
get_reserve_size
(
cudnn_handle
(
handle_cuda
()),
tensor_desc
);
// Forward
for
(
int
i
=
0
;
i
<
9
;
++
i
)
{
checker
.
set_dtype
(
i
,
dtype
::
Float32
());
}
checker
.
set_dtype
(
0
,
arg
.
dtype
);
checker
.
set_dtype
(
7
,
dtype
::
Byte
());
checker
.
set_dtype
(
8
,
arg
.
dtype
);
checker
.
set_bypass
(
7
);
checker
.
set_epsilon
(
1e-3
).
set_param
(
arg
.
param
);
for
(
bool
need_statistic
:
{
false
,
true
})
checker
.
exec
({
...
...
@@ -40,27 +54,26 @@ TEST_F(CUDA, BN_FORWARD) {
:
TensorShape
({
0
}),
// mean
need_statistic
?
arg
.
param_shape
:
TensorShape
({
0
}),
// variance
arg
.
param_shape
,
// batch_mean
arg
.
param_shape
,
// batch_inv_variance
{}
// dst
arg
.
param_shape
,
// batch_mean
arg
.
param_shape
,
// batch_inv_variance
{
reserve
},
// reserve
arg
.
src
// dst
});
}
}
TEST_F
(
CUDA
,
BN_BACKWARD
)
{
using
namespace
batch_normalization
;
std
::
vector
<
TestArg
>
args
=
get_args
();
Checker
<
BNBackward
>
checker
(
handle_cuda
());
for
(
auto
&&
arg
:
args
)
{
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
checker
.
set_dtype
(
i
,
dtype
::
Float32
());
// Backward
for
(
int
i
=
0
;
i
<
9
;
++
i
)
{
checker_bwd
.
set_dtype
(
i
,
dtype
::
Float32
());
}
checker
.
set_dtype
(
0
,
arg
.
dtype
)
// x
.
set_dtype
(
1
,
arg
.
dtype
)
// dy
.
set_dtype
(
7
,
arg
.
dtype
);
// dx
checker
.
set_epsilon
(
1e-3
).
set_param
(
arg
.
param
).
exec
(
checker_bwd
.
set_dtype
(
0
,
arg
.
dtype
)
// x
.
set_dtype
(
1
,
arg
.
dtype
)
// dy
.
set_dtype
(
5
,
dtype
::
Byte
())
// reserve
.
set_dtype
(
8
,
arg
.
dtype
)
// dx
.
set_bypass
(
5
);
checker_bwd
.
set_epsilon
(
1e-3
).
set_param
(
arg
.
param
).
exec
(
{
arg
.
src
,
arg
.
src
,
arg
.
param_shape
,
arg
.
param_shape
,
arg
.
param_shape
,
arg
.
param_shape
,
arg
.
param_shape
,
arg
.
src
});
arg
.
param_shape
,
{
reserve
},
arg
.
param_shape
,
arg
.
param_shape
,
arg
.
src
});
}
}
...
...
This diff is collapsed.
Click to expand it.
dnn/test/rocm/bn.cpp
浏览文件 @
3d3666b6
...
...
@@ -31,6 +31,7 @@ TEST_F(ROCM, BN_FORWARD) {
checker
.
set_dtype
(
i
,
dtype
::
Float32
());
}
checker
.
set_dtype
(
0
,
arg
.
dtype
);
checker
.
set_dtype
(
8
,
arg
.
dtype
);
checker
.
set_epsilon
(
1e-3
).
set_param
(
arg
.
param
);
for
(
bool
need_statistic
:
{
false
,
true
})
checker
.
exec
({
...
...
@@ -43,7 +44,8 @@ TEST_F(ROCM, BN_FORWARD) {
:
TensorShape
({
0
}),
// variance
arg
.
param_shape
,
// batch_mean
arg
.
param_shape
,
// batch_inv_variance
{}
// dst
{
0
},
// reserve
arg
.
src
// dst
});
}
}
...
...
@@ -53,15 +55,16 @@ TEST_F(ROCM, BN_BACKWARD) {
std
::
vector
<
TestArg
>
args
=
get_args
();
Checker
<
BNBackward
>
checker
(
handle_rocm
());
for
(
auto
&&
arg
:
args
)
{
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
for
(
int
i
=
0
;
i
<
9
;
++
i
)
{
checker
.
set_dtype
(
i
,
dtype
::
Float32
());
}
checker
.
set_dtype
(
0
,
arg
.
dtype
)
// x
.
set_dtype
(
1
,
arg
.
dtype
)
// dy
.
set_dtype
(
7
,
arg
.
dtype
);
// dx
.
set_dtype
(
8
,
arg
.
dtype
);
// dx
checker
.
set_epsilon
(
1e-3
).
set_param
(
arg
.
param
).
exec
(
{
arg
.
src
,
arg
.
src
,
arg
.
param_shape
,
arg
.
param_shape
,
arg
.
param_shape
,
arg
.
param_shape
,
arg
.
param_shape
,
arg
.
src
});
arg
.
param_shape
,
{
0
},
arg
.
param_shape
,
arg
.
param_shape
,
arg
.
src
});
}
}
...
...
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录
反馈
建议
客服
返回
顶部