Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
89303cd8
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
89303cd8
编写于
11月 04, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(megdnn/rocm): add bn for rocm backend
GitOrigin-RevId: 8bd49599b28847e51735ff57293bc16ff0033924
上级
7cd846a5
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
281 addition
and
10 deletion
+281
-10
cmake/rocm.cmake
cmake/rocm.cmake
+11
-10
dnn/src/rocm/batch_normalization/opr_impl.cpp
dnn/src/rocm/batch_normalization/opr_impl.cpp
+116
-0
dnn/src/rocm/batch_normalization/opr_impl.h
dnn/src/rocm/batch_normalization/opr_impl.h
+80
-0
dnn/src/rocm/handle.cpp
dnn/src/rocm/handle.cpp
+3
-0
dnn/test/rocm/bn.cpp
dnn/test/rocm/bn.cpp
+71
-0
未找到文件。
cmake/rocm.cmake
浏览文件 @
89303cd8
...
@@ -17,10 +17,11 @@ string(REPLACE "." ";" HIP_VERSION_LIST ${HIP_VERSION})
...
@@ -17,10 +17,11 @@ string(REPLACE "." ";" HIP_VERSION_LIST ${HIP_VERSION})
list
(
GET HIP_VERSION_LIST 0 HIP_VERSION_MAJOR
)
list
(
GET HIP_VERSION_LIST 0 HIP_VERSION_MAJOR
)
list
(
GET HIP_VERSION_LIST 1 HIP_VERSION_MINOR
)
list
(
GET HIP_VERSION_LIST 1 HIP_VERSION_MINOR
)
if
(
NOT
${
HIP_VERSION_MAJOR
}
STREQUAL
"3"
)
if
(
NOT
${
HIP_VERSION_MAJOR
}
STREQUAL
"3"
)
message
(
FATAL_ERROR
"ROCM version needed 3.7.Please update ROCM."
)
message
(
FATAL_ERROR
"ROCM version needed 3.x, Please update ROCM."
)
endif
()
else
()
if
(
NOT
${
HIP_VERSION_MINOR
}
STREQUAL
"7"
)
if
(
${
HIP_VERSION_MINOR
}
LESS
"7"
)
message
(
FATAL_ERROR
"ROCM version needed 3.7.Please update ROCM."
)
message
(
WARNING
"ROCM version 3.x which x(got
${
HIP_VERSION_MINOR
}
) greater equal 7 is prefered."
)
endif
()
endif
()
endif
()
set
(
MGE_ROCM_LIBS OpenCL amdhip64 MIOpen rocblas rocrand
)
set
(
MGE_ROCM_LIBS OpenCL amdhip64 MIOpen rocblas rocrand
)
...
@@ -37,7 +38,7 @@ find_path(MIOPEN_LIBRARY_DIR
...
@@ -37,7 +38,7 @@ find_path(MIOPEN_LIBRARY_DIR
DOC
"Path to MIOPEN library directory."
)
DOC
"Path to MIOPEN library directory."
)
if
(
MIOPEN_LIBRARY_DIR STREQUAL
"MIOPEN_LIBRARY_DIR-NOTFOUND"
)
if
(
MIOPEN_LIBRARY_DIR STREQUAL
"MIOPEN_LIBRARY_DIR-NOTFOUND"
)
message
(
FATAL_ERROR
"Can not find MIOPEN Library"
)
message
(
FATAL_ERROR
"Can not find MIOPEN Library"
)
endif
()
endif
()
get_filename_component
(
__found_miopen_include
${
HIP_ROOT_DIR
}
/../miopen/include REALPATH
)
get_filename_component
(
__found_miopen_include
${
HIP_ROOT_DIR
}
/../miopen/include REALPATH
)
...
@@ -48,7 +49,7 @@ find_path(MIOPEN_INCLUDE_DIR
...
@@ -48,7 +49,7 @@ find_path(MIOPEN_INCLUDE_DIR
DOC
"Path to MIOPEN include directory."
)
DOC
"Path to MIOPEN include directory."
)
if
(
MIOPEN_INCLUDE_DIR STREQUAL
"MIOPEN_INCLUDE_DIR-NOTFOUND"
)
if
(
MIOPEN_INCLUDE_DIR STREQUAL
"MIOPEN_INCLUDE_DIR-NOTFOUND"
)
message
(
FATAL_ERROR
"Can not find MIOEPN INCLUDE"
)
message
(
FATAL_ERROR
"Can not find MIOEPN INCLUDE"
)
endif
()
endif
()
#rocblas
#rocblas
...
@@ -60,7 +61,7 @@ find_path(ROCBLAS_LIBRARY_DIR
...
@@ -60,7 +61,7 @@ find_path(ROCBLAS_LIBRARY_DIR
DOC
"Path to ROCBLAS library directory."
)
DOC
"Path to ROCBLAS library directory."
)
if
(
ROCBLAS_LIBRARY_DIR STREQUAL
"ROCBLAS_LIBRARY_DIR-NOTFOUND"
)
if
(
ROCBLAS_LIBRARY_DIR STREQUAL
"ROCBLAS_LIBRARY_DIR-NOTFOUND"
)
message
(
FATAL_ERROR
"Can not find ROCBLAS Library"
)
message
(
FATAL_ERROR
"Can not find ROCBLAS Library"
)
endif
()
endif
()
get_filename_component
(
__found_rocblas_include
${
HIP_ROOT_DIR
}
/../rocblas/include REALPATH
)
get_filename_component
(
__found_rocblas_include
${
HIP_ROOT_DIR
}
/../rocblas/include REALPATH
)
...
@@ -71,7 +72,7 @@ find_path(ROCBLAS_INCLUDE_DIR
...
@@ -71,7 +72,7 @@ find_path(ROCBLAS_INCLUDE_DIR
DOC
"Path to ROCBLAS include directory."
)
DOC
"Path to ROCBLAS include directory."
)
if
(
ROCBLAS_INCLUDE_DIR STREQUAL
"ROCBLAS_INCLUDE_DIR-NOTFOUND"
)
if
(
ROCBLAS_INCLUDE_DIR STREQUAL
"ROCBLAS_INCLUDE_DIR-NOTFOUND"
)
message
(
FATAL_ERROR
"Can not find ROCBLAS INCLUDE"
)
message
(
FATAL_ERROR
"Can not find ROCBLAS INCLUDE"
)
endif
()
endif
()
#rocrand
#rocrand
...
@@ -83,7 +84,7 @@ find_path(ROCRAND_LIBRARY_DIR
...
@@ -83,7 +84,7 @@ find_path(ROCRAND_LIBRARY_DIR
DOC
"Path to ROCRAND library directory."
)
DOC
"Path to ROCRAND library directory."
)
if
(
ROCRAND_LIBRARY_DIR STREQUAL
"ROCRAND_LIBRARY_DIR-NOTFOUND"
)
if
(
ROCRAND_LIBRARY_DIR STREQUAL
"ROCRAND_LIBRARY_DIR-NOTFOUND"
)
message
(
FATAL_ERROR
"Can not find ROCRAND Library"
)
message
(
FATAL_ERROR
"Can not find ROCRAND Library"
)
endif
()
endif
()
get_filename_component
(
__found_rocrand_include
${
HIP_ROOT_DIR
}
/../rocrand/include REALPATH
)
get_filename_component
(
__found_rocrand_include
${
HIP_ROOT_DIR
}
/../rocrand/include REALPATH
)
...
@@ -94,7 +95,7 @@ find_path(ROCRAND_INCLUDE_DIR
...
@@ -94,7 +95,7 @@ find_path(ROCRAND_INCLUDE_DIR
DOC
"Path to ROCRAND include directory."
)
DOC
"Path to ROCRAND include directory."
)
if
(
ROCRAND_INCLUDE_DIR STREQUAL
"ROCRAND_INCLUDE_DIR-NOTFOUND"
)
if
(
ROCRAND_INCLUDE_DIR STREQUAL
"ROCRAND_INCLUDE_DIR-NOTFOUND"
)
message
(
FATAL_ERROR
"Can not find ROCRAND INCLUDE"
)
message
(
FATAL_ERROR
"Can not find ROCRAND INCLUDE"
)
endif
()
endif
()
dnn/src/rocm/batch_normalization/opr_impl.cpp
0 → 100644
浏览文件 @
89303cd8
/**
* \file dnn/src/rocm/batch_normalization/opr_impl.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./opr_impl.h"
#include "src/rocm/utils.h"
namespace
megdnn
{
namespace
rocm
{
namespace
batch_normalization
{
void
BNTensorDescHolder
::
setup
(
const
TensorLayout
&
x
,
const
ParamDim
&
param_dim
)
{
TensorShape
xy_shape
(
x
);
switch
(
param_dim
)
{
case
ParamDim
::
DIM_11HW
:
// xy: N, C, H, W --> (N*C), 1, H, W
xy_shape
.
shape
[
0
]
=
xy_shape
.
shape
[
0
]
*
xy_shape
.
shape
[
1
];
xy_shape
.
shape
[
1
]
=
1
;
bn_mode
=
miopenBNPerActivation
;
break
;
case
ParamDim
::
DIM_1CHW
:
bn_mode
=
miopenBNPerActivation
;
break
;
case
ParamDim
::
DIM_1C11
:
bn_mode
=
miopenBNSpatial
;
break
;
default:
megdnn_throw
(
megdnn_mangle
(
"Unknown param dim type of batch normalization."
));
}
xy_desc
.
set
(
TensorLayout
(
xy_shape
,
x
.
dtype
));
param_desc
.
set
(
xy_desc
.
desc
,
bn_mode
);
}
}
// namespace batch_normalization
void
BNForwardImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_in
bn_bias
,
_megdnn_tensor_out
mean
,
_megdnn_tensor_out
variance
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
check_exec
(
src
.
layout
,
bn_scale
.
layout
,
bn_bias
.
layout
,
mean
.
layout
,
variance
.
layout
,
batch_mean
.
layout
,
batch_inv_variance
.
layout
,
dst
.
layout
,
workspace
.
size
);
auto
handle
=
concrete_handle
(
this
->
handle
())
->
miopen_handle
();
m_tensor_desc
.
setup
(
src
.
layout
,
m_param
.
param_dim
);
float
alpha
=
1.0
f
,
beta
=
0.0
f
;
switch
(
m_param
.
fwd_mode
)
{
case
param
::
BN
::
FwdMode
::
TRAINING
:
miopen_check
(
miopenBatchNormalizationForwardTraining
(
handle
,
m_tensor_desc
.
bn_mode
,
&
alpha
,
&
beta
,
m_tensor_desc
.
xy_desc
.
desc
,
// xDesc
src
.
raw_ptr
,
// x
m_tensor_desc
.
xy_desc
.
desc
,
// yDesc
dst
.
raw_ptr
,
// y
m_tensor_desc
.
param_desc
.
desc
,
// bnScaleBiasMeanVarDesc
bn_scale
.
raw_ptr
,
bn_bias
.
raw_ptr
,
m_param
.
avg_factor
,
mean
.
raw_ptr
,
variance
.
raw_ptr
,
m_param
.
epsilon
,
batch_mean
.
raw_ptr
,
batch_inv_variance
.
raw_ptr
));
break
;
case
param
::
BN
::
FwdMode
::
INFERENCE
:
miopen_check
(
miopenBatchNormalizationForwardInference
(
handle
,
m_tensor_desc
.
bn_mode
,
&
alpha
,
&
beta
,
m_tensor_desc
.
xy_desc
.
desc
,
src
.
raw_ptr
,
m_tensor_desc
.
xy_desc
.
desc
,
dst
.
raw_ptr
,
m_tensor_desc
.
param_desc
.
desc
,
bn_scale
.
raw_ptr
,
bn_bias
.
raw_ptr
,
mean
.
raw_ptr
,
variance
.
raw_ptr
,
m_param
.
epsilon
));
break
;
default:
megdnn_throw
(
megdnn_mangle
(
"Unknown forward mode type of batch normalization."
));
}
}
void
BNBackwardImpl
::
exec
(
_megdnn_tensor_in
x
,
_megdnn_tensor_in
dy
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_inv_variance
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_out
d_bn_scale
,
_megdnn_tensor_out
d_bn_bias
,
_megdnn_tensor_out
dx
,
_megdnn_workspace
workspace
)
{
check_exec
(
x
.
layout
,
dy
.
layout
,
saved_batch_mean
.
layout
,
saved_batch_inv_variance
.
layout
,
bn_scale
.
layout
,
d_bn_scale
.
layout
,
d_bn_bias
.
layout
,
dx
.
layout
,
workspace
.
size
);
auto
handle
=
concrete_handle
(
this
->
handle
())
->
miopen_handle
();
m_tensor_desc
.
setup
(
x
.
layout
,
m_param
.
param_dim
);
float
alpha
=
1.0
,
beta
=
0.0
;
miopen_check
(
miopenBatchNormalizationBackward
(
handle
,
m_tensor_desc
.
bn_mode
,
&
alpha
,
&
beta
,
&
alpha
,
&
beta
,
m_tensor_desc
.
xy_desc
.
desc
,
x
.
raw_ptr
,
m_tensor_desc
.
xy_desc
.
desc
,
dy
.
raw_ptr
,
m_tensor_desc
.
xy_desc
.
desc
,
dx
.
raw_ptr
,
m_tensor_desc
.
param_desc
.
desc
,
bn_scale
.
raw_ptr
,
d_bn_scale
.
raw_ptr
,
d_bn_bias
.
raw_ptr
,
m_param
.
epsilon
,
saved_batch_mean
.
raw_ptr
,
saved_batch_inv_variance
.
raw_ptr
));
}
}
// namespace rocm
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/rocm/batch_normalization/opr_impl.h
0 → 100644
浏览文件 @
89303cd8
/**
* \file dnn/src/rocm/batch_normalization/opr_impl.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/oprs.h"
#include "src/rocm/miopen_wrapper.h"
namespace
megdnn
{
namespace
rocm
{
namespace
batch_normalization
{
struct
BNTensorDescHolder
{
using
ParamDim
=
param
::
BN
::
ParamDim
;
TensorDesc
xy_desc
;
BNParamDesc
param_desc
;
miopenBatchNormMode_t
bn_mode
;
void
setup
(
const
TensorLayout
&
x
,
const
ParamDim
&
param_dim
);
};
}
// namespace batch_normalization
class
BNForwardImpl
final
:
public
BNForward
{
public:
using
BNForward
::
BNForward
;
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_in
bn_bias
,
_megdnn_tensor_out
mean
,
_megdnn_tensor_out
variance
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
return
0
;
}
private:
batch_normalization
::
BNTensorDescHolder
m_tensor_desc
;
};
class
BNBackwardImpl
final
:
public
BNBackward
{
public:
using
BNBackward
::
BNBackward
;
void
exec
(
_megdnn_tensor_in
x
,
_megdnn_tensor_in
dy
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_inv_variance
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_out
d_bn_scale
,
_megdnn_tensor_out
d_bn_bias
,
_megdnn_tensor_out
dx
,
_megdnn_workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
return
0
;
}
private:
batch_normalization
::
BNTensorDescHolder
m_tensor_desc
;
};
}
// namespace rocm
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/rocm/handle.cpp
浏览文件 @
89303cd8
...
@@ -35,6 +35,7 @@
...
@@ -35,6 +35,7 @@
#include "src/rocm/linspace/opr_impl.h"
#include "src/rocm/linspace/opr_impl.h"
#include "src/rocm/argmxx/opr_impl.h"
#include "src/rocm/argmxx/opr_impl.h"
#include "src/rocm/sleep/opr_impl.h"
#include "src/rocm/sleep/opr_impl.h"
#include "src/rocm/batch_normalization/opr_impl.h"
#include <miopen/version.h>
#include <miopen/version.h>
#include <hip/hip_version.h>
#include <hip/hip_version.h>
...
@@ -171,6 +172,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace);
...
@@ -171,6 +172,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(Linspace);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
ArgmaxForward
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
ArgmaxForward
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
ArgminForward
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
ArgminForward
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
SleepForward
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
SleepForward
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
BNForward
);
MEGDNN_SPECIALIZE_CREATE_OPERATOR
(
BNBackward
);
#pragma GCC diagnostic push
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
#pragma GCC diagnostic ignored "-Wpragmas"
...
...
dnn/test/rocm/bn.cpp
0 → 100644
浏览文件 @
89303cd8
/**
* \file dnn/test/rocm/bn.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "test/rocm/fixture.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs.h"
#include "test/common/bn.h"
#include "test/common/checker.h"
#include "test/common/rng.h"
#include "test/common/tensor.h"
#include "test/common/workspace_wrapper.h"
namespace
megdnn
{
namespace
test
{
TEST_F
(
ROCM
,
BN_FORWARD
)
{
using
namespace
batch_normalization
;
std
::
vector
<
TestArg
>
args
=
get_args
();
Checker
<
BNForward
>
checker
(
handle_rocm
());
for
(
auto
&&
arg
:
args
)
{
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
checker
.
set_dtype
(
i
,
dtype
::
Float32
());
}
checker
.
set_dtype
(
0
,
arg
.
dtype
);
checker
.
set_epsilon
(
1e-3
).
set_param
(
arg
.
param
);
for
(
bool
need_statistic
:
{
false
,
true
})
checker
.
exec
({
arg
.
src
,
arg
.
param_shape
,
// bn_scale
arg
.
param_shape
,
// bn_bias
need_statistic
?
arg
.
param_shape
:
TensorShape
({
0
}),
// mean
need_statistic
?
arg
.
param_shape
:
TensorShape
({
0
}),
// variance
arg
.
param_shape
,
// batch_mean
arg
.
param_shape
,
// batch_inv_variance
{}
// dst
});
}
}
TEST_F
(
ROCM
,
BN_BACKWARD
)
{
using
namespace
batch_normalization
;
std
::
vector
<
TestArg
>
args
=
get_args
();
Checker
<
BNBackward
>
checker
(
handle_rocm
());
for
(
auto
&&
arg
:
args
)
{
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
checker
.
set_dtype
(
i
,
dtype
::
Float32
());
}
checker
.
set_dtype
(
0
,
arg
.
dtype
)
// x
.
set_dtype
(
1
,
arg
.
dtype
)
// dy
.
set_dtype
(
7
,
arg
.
dtype
);
// dx
checker
.
set_epsilon
(
1e-3
).
set_param
(
arg
.
param
).
exec
(
{
arg
.
src
,
arg
.
src
,
arg
.
param_shape
,
arg
.
param_shape
,
arg
.
param_shape
,
arg
.
param_shape
,
arg
.
param_shape
,
arg
.
src
});
}
}
}
// namespace test
}
// namespace megdnn
// vim: syntax=cpp.doxygen
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录