Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
b3e54ead
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
b3e54ead
编写于
8月 17, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn/bn): use new cudnn BN kernel to support NHWC
GitOrigin-RevId: 9d80f2009d0496f532b267fcf841785b74c0b50c
上级
6b863cc5
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
268 addition
and
123 deletion
+268
-123
dnn/include/megdnn/oprs/nn.h
dnn/include/megdnn/oprs/nn.h
+29
-17
dnn/scripts/opr_param_defs.py
dnn/scripts/opr_param_defs.py
+2
-1
dnn/src/common/batch_normalization.cpp
dnn/src/common/batch_normalization.cpp
+28
-20
dnn/src/common/opr_trait.h
dnn/src/common/opr_trait.h
+2
-2
dnn/src/cuda/batch_normalization/opr_impl.cpp
dnn/src/cuda/batch_normalization/opr_impl.cpp
+153
-36
dnn/src/cuda/batch_normalization/opr_impl.h
dnn/src/cuda/batch_normalization/opr_impl.h
+23
-25
dnn/src/naive/batch_normalization/opr_impl.cpp
dnn/src/naive/batch_normalization/opr_impl.cpp
+8
-6
dnn/src/naive/batch_normalization/opr_impl.h
dnn/src/naive/batch_normalization/opr_impl.h
+11
-7
dnn/src/rocm/batch_normalization/opr_impl.cpp
dnn/src/rocm/batch_normalization/opr_impl.cpp
+3
-2
dnn/src/rocm/batch_normalization/opr_impl.h
dnn/src/rocm/batch_normalization/opr_impl.h
+9
-7
未找到文件。
dnn/include/megdnn/oprs/nn.h
浏览文件 @
b3e54ead
...
@@ -682,7 +682,8 @@ public:
...
@@ -682,7 +682,8 @@ public:
* http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html
* http://deeplearning.net/software/theano/library/tensor/nnet/neighbours.html
*
*
* \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+fw}\f$,
* \f$ dst_{n, c, oh, ow, wh, ww} = src_{n, c, ih+wh, iw+fw}\f$,
* where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1), iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$.
* where \f$ ih=-pad_h+oh*stride_h+(wh-1)*(dilation_h-1),
* iw=-pad_w+ow*stride_w+(ww-1)*(dilation_w-1)\f$.
*/
*/
virtual
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
,
virtual
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
=
0
;
_megdnn_workspace
workspace
)
=
0
;
...
@@ -724,7 +725,8 @@ protected:
...
@@ -724,7 +725,8 @@ protected:
};
};
class
SlidingWindowTransposeForward
:
public
SlidingWindowTransposeBase
{
class
SlidingWindowTransposeForward
:
public
SlidingWindowTransposeBase
{
DEF_OPR_IMPL
(
SlidingWindowTransposeForward
,
SlidingWindowTransposeBase
,
1
,
1
);
DEF_OPR_IMPL
(
SlidingWindowTransposeForward
,
SlidingWindowTransposeBase
,
1
,
1
);
public:
public:
/**
/**
...
@@ -744,7 +746,8 @@ protected:
...
@@ -744,7 +746,8 @@ protected:
using
SlidingWindowTranspose
=
SlidingWindowTransposeForward
;
using
SlidingWindowTranspose
=
SlidingWindowTransposeForward
;
class
SlidingWindowTransposeBackward
:
public
SlidingWindowTransposeBase
{
class
SlidingWindowTransposeBackward
:
public
SlidingWindowTransposeBase
{
DEF_OPR_IMPL
(
SlidingWindowTransposeBackward
,
SlidingWindowTransposeBase
,
1
,
1
);
DEF_OPR_IMPL
(
SlidingWindowTransposeBackward
,
SlidingWindowTransposeBase
,
1
,
1
);
public:
public:
/**
/**
...
@@ -975,7 +978,7 @@ protected:
...
@@ -975,7 +978,7 @@ protected:
};
};
class
BNForward
:
public
BNBase
{
class
BNForward
:
public
BNBase
{
DEF_OPR_IMPL
(
BNForward
,
BNBase
,
6
,
5
);
DEF_OPR_IMPL
(
BNForward
,
BNBase
,
6
,
6
);
public:
public:
/**
/**
...
@@ -986,10 +989,11 @@ public:
...
@@ -986,10 +989,11 @@ public:
* \param[out] dst (n, c, h, w)
* \param[out] dst (n, c, h, w)
* \param[out] mean (see m_param.ParamDim) Global mean.
* \param[out] mean (see m_param.ParamDim) Global mean.
* \param[out] variance (see m_param.ParamDim) Global variance.
* \param[out] variance (see m_param.ParamDim) Global variance.
* \
P
aram[out] batch_mean (see m_param.ParamDim)
* \
p
aram[out] batch_mean (see m_param.ParamDim)
* Optionally cached intermediate mean from forward pass
* Optionally cached intermediate mean from forward pass
* \
P
aram[out] batch_inv_variance (see m_param.ParamDim)
* \
p
aram[out] batch_inv_variance (see m_param.ParamDim)
* Optionally cached intermediate variance from forward pass
* Optionally cached intermediate variance from forward pass
* \param[out] reserve (see cudnnBatchNormalizationForwardTrainingEx)
* src and dst must have the same shape.
* src and dst must have the same shape.
* src and dst must be contiguous.
* src and dst must be contiguous.
*/
*/
...
@@ -998,17 +1002,20 @@ public:
...
@@ -998,17 +1002,20 @@ public:
_megdnn_tensor_inout
variance
,
_megdnn_tensor_inout
variance
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
=
0
;
_megdnn_tensor_out
reserve
,
_megdnn_tensor_out
dst
,
void
deduce_layout
(
const
TensorLayout
&
src
,
TensorLayout
&
bn_scale
,
_megdnn_workspace
workspace
)
=
0
;
TensorLayout
&
bn_bias
,
TensorLayout
&
mean
,
void
deduce_layout
(
const
TensorLayout
&
src
,
const
TensorLayout
&
bn_scale
,
const
TensorLayout
&
bn_bias
,
TensorLayout
&
mean
,
TensorLayout
&
variance
,
TensorLayout
&
batch_mean
,
TensorLayout
&
variance
,
TensorLayout
&
batch_mean
,
TensorLayout
&
batch_inv_variance
,
TensorLayout
&
dst
);
TensorLayout
&
batch_inv_variance
,
TensorLayout
&
reserve
,
TensorLayout
&
dst
);
virtual
size_t
get_workspace_in_bytes
(
virtual
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
bn_scale
,
const
TensorLayout
&
src
,
const
TensorLayout
&
bn_scale
,
const
TensorLayout
&
bn_bias
,
const
TensorLayout
&
mean
,
const
TensorLayout
&
bn_bias
,
const
TensorLayout
&
mean
,
const
TensorLayout
&
variance
,
const
TensorLayout
&
batch_mean
,
const
TensorLayout
&
variance
,
const
TensorLayout
&
batch_mean
,
const
TensorLayout
&
batch_inv_variance
,
const
TensorLayout
&
batch_inv_variance
,
const
TensorLayout
&
reserve
,
const
TensorLayout
&
dst
)
=
0
;
const
TensorLayout
&
dst
)
=
0
;
virtual
size_t
get_reserve_in_bytes
(
const
TensorLayout
&
src
)
=
0
;
protected:
protected:
void
check_exec
(
const
TensorLayout
&
src
,
const
TensorLayout
&
bn_scale
,
void
check_exec
(
const
TensorLayout
&
src
,
const
TensorLayout
&
bn_scale
,
...
@@ -1016,12 +1023,13 @@ protected:
...
@@ -1016,12 +1023,13 @@ protected:
const
TensorLayout
&
variance
,
const
TensorLayout
&
variance
,
const
TensorLayout
&
batch_mean
,
const
TensorLayout
&
batch_mean
,
const
TensorLayout
&
batch_inv_variance
,
const
TensorLayout
&
batch_inv_variance
,
const
TensorLayout
&
dst
,
size_t
workspace_in_bytes
);
const
TensorLayout
&
dst
,
size_t
workspace_in_bytes
,
size_t
reserve_in_bytes
=
0
);
};
};
using
BN
=
BNForward
;
using
BN
=
BNForward
;
class
BNBackward
:
public
BNBase
{
class
BNBackward
:
public
BNBase
{
DEF_OPR_IMPL
(
BNBackward
,
BNBase
,
5
,
3
);
DEF_OPR_IMPL
(
BNBackward
,
BNBase
,
6
,
3
);
public:
public:
/**
/**
...
@@ -1035,19 +1043,23 @@ public:
...
@@ -1035,19 +1043,23 @@ public:
Calculated in the forwardpropagation.
Calculated in the forwardpropagation.
* \param[in] saved_batch_variance of the input batch.
* \param[in] saved_batch_variance of the input batch.
Calculated in the forwardpropagation.
Calculated in the forwardpropagation.
* \param[in] reserve (see cudnnBatchNormalizationBackwardEx)
*/
*/
virtual
void
exec
(
_megdnn_tensor_in
x
,
_megdnn_tensor_in
dy
,
virtual
void
exec
(
_megdnn_tensor_in
x
,
_megdnn_tensor_in
dy
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_variance
,
_megdnn_tensor_in
saved_batch_variance
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_out
d_bn_scale
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_in
reserve
,
_megdnn_tensor_out
d_bn_scale
,
_megdnn_tensor_out
d_bn_bias
,
_megdnn_tensor_out
dx
,
_megdnn_tensor_out
d_bn_bias
,
_megdnn_tensor_out
dx
,
_megdnn_workspace
workspace
)
=
0
;
_megdnn_workspace
workspace
)
=
0
;
virtual
size_t
get_workspace_in_bytes
(
virtual
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
x
,
const
TensorLayout
&
dy
,
const
TensorLayout
&
x
,
const
TensorLayout
&
dy
,
const
TensorLayout
&
saved_batch_mean
,
const
TensorLayout
&
saved_batch_mean
,
const
TensorLayout
&
saved_batch_variance
,
const
TensorLayout
&
saved_batch_variance
,
const
TensorLayout
&
bn_scale
,
const
TensorLayout
&
d_bn_scale
,
const
TensorLayout
&
bn_scale
,
const
TensorLayout
&
reserve
,
const
TensorLayout
&
d_bn_bias
,
const
TensorLayout
&
dx
)
=
0
;
const
TensorLayout
&
d_bn_scale
,
const
TensorLayout
&
d_bn_bias
,
const
TensorLayout
&
dx
)
=
0
;
virtual
size_t
get_reserve_in_bytes
(
const
TensorLayout
&
src
)
=
0
;
protected:
protected:
void
check_exec
(
const
TensorLayout
&
x
,
const
TensorLayout
&
dy
,
void
check_exec
(
const
TensorLayout
&
x
,
const
TensorLayout
&
dy
,
...
@@ -1056,7 +1068,7 @@ protected:
...
@@ -1056,7 +1068,7 @@ protected:
const
TensorLayout
&
bn_scale
,
const
TensorLayout
&
bn_scale
,
const
TensorLayout
&
d_bn_scale
,
const
TensorLayout
&
d_bn_scale
,
const
TensorLayout
&
d_bn_bias
,
const
TensorLayout
&
dx
,
const
TensorLayout
&
d_bn_bias
,
const
TensorLayout
&
dx
,
size_t
workspace_in_bytes
);
size_t
workspace_in_bytes
,
size_t
reserve_in_bytes
=
0
);
};
};
class
LRNBase
:
public
OperatorBase
{
class
LRNBase
:
public
OperatorBase
{
...
...
dnn/scripts/opr_param_defs.py
浏览文件 @
b3e54ead
...
@@ -253,7 +253,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
...
@@ -253,7 +253,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
add_enum_alias
(
'Format'
,
'Convolution'
)
add_enum_alias
(
'Format'
,
'Convolution'
)
)
)
(
pdef
(
'AdaptivePooling'
,
version
=
0
,
is_legacy
=
True
).
(
pdef
(
'AdaptivePooling'
,
version
=
0
,
is_legacy
=
True
).
add_enum_alias
(
'Mode'
,
'PoolingV0'
).
add_enum_alias
(
'Mode'
,
'PoolingV0'
).
add_enum_alias
(
'Format'
,
'ConvolutionV0'
)
add_enum_alias
(
'Format'
,
'ConvolutionV0'
)
)
)
...
@@ -276,6 +276,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
...
@@ -276,6 +276,7 @@ pdef('Axis').add_fields('int32', 'axis', 0)
Doc
(
'DIM_11HW = 0'
,
'Dim of params (Sigma, Mu) is 1 x 1 x H x W'
),
Doc
(
'DIM_11HW = 0'
,
'Dim of params (Sigma, Mu) is 1 x 1 x H x W'
),
Doc
(
'DIM_1CHW = 1'
,
'Dim of params (Sigma, Mu) is 1 x C x H x W'
),
Doc
(
'DIM_1CHW = 1'
,
'Dim of params (Sigma, Mu) is 1 x C x H x W'
),
Doc
(
'DIM_1C11 = 2'
,
'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'
),
Doc
(
'DIM_1C11 = 2'
,
'Dim of params (Sigma, Mu) is 1 x C x 1 x 1'
),
Doc
(
'DIM_111C = 3'
,
'Dim of params (Sigma, Mu) is 1 x 1 x 1 x C'
),
name_field
=
'param_dim'
name_field
=
'param_dim'
).
).
add_enum
(
add_enum
(
...
...
dnn/src/common/batch_normalization.cpp
浏览文件 @
b3e54ead
...
@@ -4,9 +4,9 @@
...
@@ -4,9 +4,9 @@
*
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
software
*
software distributed under the License is distributed on an
*
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
*
"AS IS" BASIS, WITHOUT
ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
*/
#include "megdnn/oprs.h"
#include "megdnn/oprs.h"
...
@@ -14,28 +14,32 @@
...
@@ -14,28 +14,32 @@
namespace
megdnn
{
namespace
megdnn
{
void
BNForward
::
deduce_layout
(
const
TensorLayout
&
src
,
TensorLayout
&
,
void
BNForward
::
deduce_layout
(
const
TensorLayout
&
src
,
const
TensorLayout
&
,
TensorLayout
&
,
TensorLayout
&
,
TensorLayout
&
,
const
TensorLayout
&
,
TensorLayout
&
,
TensorLayout
&
,
TensorLayout
&
,
TensorLayout
&
,
TensorLayout
&
dst
)
{
TensorLayout
&
,
TensorLayout
&
,
TensorLayout
&
reserve
,
TensorLayout
&
dst
)
{
reserve
=
{{
get_reserve_in_bytes
(
src
)},
dtype
::
Byte
()};
dst
=
src
;
dst
=
src
;
}
}
void
BNForward
::
check_exec
(
const
TensorLayout
&
src
,
const
TensorLayout
&
bn_scale
,
void
BNForward
::
check_exec
(
const
TensorLayout
&
src
,
const
TensorLayout
&
bn_scale
,
const
TensorLayout
&
bn_bias
,
const
TensorLayout
&
mean
,
const
TensorLayout
&
bn_bias
,
const
TensorLayout
&
mean
,
const
TensorLayout
&
variance
,
const
TensorLayout
&
variance
,
const
TensorLayout
&
batch_mean
,
const
TensorLayout
&
batch_mean
,
const
TensorLayout
&
batch_inv_variance
,
const
TensorLayout
&
dst
,
const
TensorLayout
&
batch_inv_variance
,
size_t
workspace_in_bytes
,
size_t
reserve_in_bytes
)
{
const
TensorLayout
&
dst
,
size_t
workspace_in_bytes
)
{
megdnn_assert_contiguous
(
src
);
megdnn_assert_contiguous
(
src
);
megdnn_assert_eq_layout
(
src
,
dst
);
megdnn_assert_eq_layout
(
src
,
dst
);
megdnn_assert_eq_layout
(
bn_scale
,
bn_bias
);
megdnn_assert_eq_layout
(
bn_scale
,
bn_bias
);
megdnn_assert
(
src
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
);
megdnn_assert
(
src
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
);
megdnn_assert
(
bn_scale
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
);
megdnn_assert
(
bn_scale
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
);
auto
required_workspace_in_bytes
=
auto
required_workspace_in_bytes
=
get_workspace_in_bytes
(
get_workspace_in_bytes
(
src
,
bn_scale
,
bn_bias
,
mean
,
variance
,
src
,
bn_scale
,
bn_bias
,
mean
,
variance
,
batch_mean
,
batch_mean
,
batch_inv_variance
,
dst
);
batch_inv_variance
,
{}
,
dst
);
megdnn_assert
(
workspace_in_bytes
>=
required_workspace_in_bytes
);
megdnn_assert
(
workspace_in_bytes
>=
required_workspace_in_bytes
);
auto
required_reserve_in_bytes
=
get_reserve_in_bytes
(
src
);
megdnn_assert
(
reserve_in_bytes
>=
required_reserve_in_bytes
);
}
}
void
BNBackward
::
check_exec
(
const
TensorLayout
&
x
,
const
TensorLayout
&
dy
,
void
BNBackward
::
check_exec
(
const
TensorLayout
&
x
,
const
TensorLayout
&
dy
,
...
@@ -44,7 +48,8 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy,
...
@@ -44,7 +48,8 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy,
const
TensorLayout
&
bn_scale
,
const
TensorLayout
&
bn_scale
,
const
TensorLayout
&
d_bn_scale
,
const
TensorLayout
&
d_bn_scale
,
const
TensorLayout
&
d_bn_bias
,
const
TensorLayout
&
d_bn_bias
,
const
TensorLayout
&
dx
,
size_t
workspace_in_bytes
)
{
const
TensorLayout
&
dx
,
size_t
workspace_in_bytes
,
size_t
reserve_in_bytes
)
{
megdnn_assert_contiguous
(
x
);
megdnn_assert_contiguous
(
x
);
megdnn_assert_eq_layout
(
x
,
dy
);
megdnn_assert_eq_layout
(
x
,
dy
);
megdnn_assert_eq_layout
(
x
,
dx
);
megdnn_assert_eq_layout
(
x
,
dx
);
...
@@ -54,11 +59,14 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy,
...
@@ -54,11 +59,14 @@ void BNBackward::check_exec(const TensorLayout& x, const TensorLayout& dy,
megdnn_assert_eq_layout
(
saved_batch_mean
,
bn_scale
);
megdnn_assert_eq_layout
(
saved_batch_mean
,
bn_scale
);
megdnn_assert
(
x
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
);
megdnn_assert
(
x
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
);
megdnn_assert
(
bn_scale
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
);
megdnn_assert
(
bn_scale
.
dtype
.
category
()
==
DTypeCategory
::
FLOAT
);
auto
required_workspace_in_bytes
=
auto
required_workspace_in_bytes
=
get_workspace_in_bytes
(
get_workspace_in_bytes
(
x
,
dy
,
saved_batch_mean
,
saved_batch_variance
,
x
,
dy
,
saved_batch_mean
,
saved_batch_variance
,
bn_scale
,
{}
,
bn_scale
,
d_bn_scale
,
d_bn_bias
,
dx
);
d_bn_scale
,
d_bn_bias
,
dx
);
megdnn_assert
(
workspace_in_bytes
>=
required_workspace_in_bytes
);
megdnn_assert
(
workspace_in_bytes
>=
required_workspace_in_bytes
);
megdnn_assert
(
param
().
fwd_mode
==
Param
::
FwdMode
::
TRAINING
,
"BNBackward only support TRAINING mode"
);
auto
required_reserve_in_bytes
=
get_reserve_in_bytes
(
x
);
megdnn_assert
(
reserve_in_bytes
>=
required_reserve_in_bytes
);
megdnn_assert
(
param
().
fwd_mode
==
Param
::
FwdMode
::
TRAINING
,
"BNBackward only support TRAINING mode"
);
}
}
}
// namespace megdnn
}
// namespace megdnn
...
...
dnn/src/common/opr_trait.h
浏览文件 @
b3e54ead
...
@@ -55,8 +55,8 @@ DEF(GroupLocalBackwardData, 3, true, false);
...
@@ -55,8 +55,8 @@ DEF(GroupLocalBackwardData, 3, true, false);
DEF
(
GroupLocalBackwardFilter
,
3
,
true
,
false
);
DEF
(
GroupLocalBackwardFilter
,
3
,
true
,
false
);
DEF
(
LRNForward
,
2
,
true
,
true
);
DEF
(
LRNForward
,
2
,
true
,
true
);
DEF
(
LRNBackward
,
4
,
true
,
false
);
DEF
(
LRNBackward
,
4
,
true
,
false
);
DEF
(
BNForward
,
8
,
true
,
true
);
DEF
(
BNForward
,
9
,
true
,
true
);
DEF
(
BNBackward
,
8
,
true
,
false
);
DEF
(
BNBackward
,
9
,
true
,
false
);
DEF
(
ROIPoolingForward
,
4
,
true
,
false
);
DEF
(
ROIPoolingForward
,
4
,
true
,
false
);
DEF
(
ROIPoolingBackward
,
5
,
true
,
false
);
DEF
(
ROIPoolingBackward
,
5
,
true
,
false
);
DEF
(
CorrelationForward
,
3
,
true
,
true
);
DEF
(
CorrelationForward
,
3
,
true
,
true
);
...
...
dnn/src/cuda/batch_normalization/opr_impl.cpp
浏览文件 @
b3e54ead
...
@@ -4,9 +4,9 @@
...
@@ -4,9 +4,9 @@
*
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
software
*
software distributed under the License is distributed on an
*
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
*
"AS IS" BASIS, WITHOUT
ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
*/
#include "./opr_impl.h"
#include "./opr_impl.h"
...
@@ -17,9 +17,11 @@ namespace cuda {
...
@@ -17,9 +17,11 @@ namespace cuda {
namespace
batch_normalization
{
namespace
batch_normalization
{
void
BNTensorDescHolder
::
setup
(
const
TensorLayout
&
x
,
BNTensorDescHolder
::
BNTensorDescHolder
(
const
TensorLayout
&
x
,
const
ParamDim
&
param_dim
)
{
const
ParamDim
&
param_dim
,
const
FwdMode
&
fwd_mode
)
{
TensorShape
xy_shape
(
x
);
TensorShape
xy_shape
(
x
);
Format
xy_format
=
Format
::
NCHW
;
switch
(
param_dim
)
{
switch
(
param_dim
)
{
case
ParamDim
::
DIM_11HW
:
case
ParamDim
::
DIM_11HW
:
...
@@ -34,50 +36,116 @@ void BNTensorDescHolder::setup(const TensorLayout& x,
...
@@ -34,50 +36,116 @@ void BNTensorDescHolder::setup(const TensorLayout& x,
case
ParamDim
::
DIM_1C11
:
case
ParamDim
::
DIM_1C11
:
bn_mode
=
CUDNN_BATCHNORM_SPATIAL
;
bn_mode
=
CUDNN_BATCHNORM_SPATIAL
;
break
;
break
;
case
ParamDim
::
DIM_111C
:
bn_mode
=
CUDNN_BATCHNORM_SPATIAL
;
xy_format
=
Format
::
NHWC
;
#if CUDNN_VERSION >= 7410
if
(
fwd_mode
==
FwdMode
::
TRAINING
)
{
bn_mode
=
CUDNN_BATCHNORM_SPATIAL_PERSISTENT
;
}
#endif // CUDNN_VERSION >= 7400
break
;
default:
default:
megdnn_throw
(
"Unknown param dim type of batch normalization."
);
megdnn_throw
(
"Unknown param dim type of batch normalization."
);
}
}
xy_desc
.
set
(
TensorLayout
(
xy_shape
,
x
.
dtype
));
xy_desc
.
set
(
TensorLayout
(
xy_shape
,
x
.
dtype
)
,
xy_format
);
param_desc
.
set
(
xy_desc
.
desc
,
bn_mode
);
param_desc
.
set
(
xy_desc
.
desc
,
bn_mode
);
}
}
size_t
get_reserve_size
(
const
cudnnHandle_t
&
handle
,
const
BNTensorDescHolder
&
tensor_desc
)
{
#if CUDNN_VERSION >= 7410
size_t
reserve_size
;
cudnn_check
(
cudnnGetBatchNormalizationTrainingExReserveSpaceSize
(
handle
,
tensor_desc
.
bn_mode
,
CUDNN_BATCHNORM_OPS_BN
,
nullptr
,
// activationDesc
tensor_desc
.
xy_desc
.
desc
,
// xDesc
&
reserve_size
));
return
reserve_size
;
#else
return
0
;
#endif // CUDNN_VERSION >= 7410
}
}
// namespace batch_normalization
}
// namespace batch_normalization
using
batch_normalization
::
BNTensorDescHolder
;
size_t
BNForwardImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
#if CUDNN_VERSION >= 7410
auto
handle
=
cudnn_handle
(
this
->
handle
());
BNTensorDescHolder
tensor_desc
(
src
,
m_param
.
param_dim
,
m_param
.
fwd_mode
);
size_t
workspace_size
;
cudnn_check
(
cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize
(
handle
,
tensor_desc
.
bn_mode
,
CUDNN_BATCHNORM_OPS_BN
,
tensor_desc
.
xy_desc
.
desc
,
// xDesc
tensor_desc
.
xy_desc
.
desc
,
// yDesc
tensor_desc
.
xy_desc
.
desc
,
// zDesc
tensor_desc
.
param_desc
.
desc
,
// bnScaleBiasMeanVarDesc
nullptr
,
// activationDesc
&
workspace_size
));
return
workspace_size
;
#else
return
0
;
#endif // CUDNN_VERSION >= 7410
}
size_t
BNForwardImpl
::
get_reserve_in_bytes
(
const
TensorLayout
&
src
)
{
BNTensorDescHolder
tensor_desc
(
src
,
m_param
.
param_dim
,
m_param
.
fwd_mode
);
return
batch_normalization
::
get_reserve_size
(
cudnn_handle
(
this
->
handle
()),
tensor_desc
);
}
void
BNForwardImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
bn_scale
,
void
BNForwardImpl
::
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_in
bn_bias
,
_megdnn_tensor_out
mean
,
_megdnn_tensor_in
bn_bias
,
_megdnn_tensor_out
mean
,
_megdnn_tensor_out
variance
,
_megdnn_tensor_out
variance
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
_megdnn_tensor_out
reserve
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
check_exec
(
src
.
layout
,
bn_scale
.
layout
,
bn_bias
.
layout
,
mean
.
layout
,
check_exec
(
src
.
layout
,
bn_scale
.
layout
,
bn_bias
.
layout
,
mean
.
layout
,
variance
.
layout
,
batch_mean
.
layout
,
batch_inv_variance
.
layout
,
variance
.
layout
,
batch_mean
.
layout
,
batch_inv_variance
.
layout
,
dst
.
layout
,
workspace
.
size
);
dst
.
layout
,
workspace
.
size
,
reserve
.
layout
.
access_bytes
()
);
auto
handle
=
cudnn_handle
(
this
->
handle
());
auto
handle
=
cudnn_handle
(
this
->
handle
());
m_tensor_desc
.
setup
(
src
.
layout
,
m_param
.
param_dim
);
BNTensorDescHolder
tensor_desc
(
src
.
layout
,
m_param
.
param_dim
,
m_param
.
fwd_mode
);
float
alpha
=
1.0
f
,
beta
=
0.0
f
;
float
alpha
=
1.0
f
,
beta
=
0.0
f
;
switch
(
m_param
.
fwd_mode
)
{
switch
(
m_param
.
fwd_mode
)
{
case
param
::
BN
::
FwdMode
::
TRAINING
:
case
param
::
BN
::
FwdMode
::
TRAINING
:
#if CUDNN_VERSION >= 7410
cudnn_check
(
cudnnBatchNormalizationForwardTrainingEx
(
handle
,
tensor_desc
.
bn_mode
,
CUDNN_BATCHNORM_OPS_BN
,
&
alpha
,
&
beta
,
// one & zero
tensor_desc
.
xy_desc
.
desc
,
src
.
raw_ptr
,
// xDesc & x
nullptr
,
nullptr
,
// zDesc & z
tensor_desc
.
xy_desc
.
desc
,
dst
.
raw_ptr
,
// yDesc & y
tensor_desc
.
param_desc
.
desc
,
// bnScaleBiasMeanVarDesc
bn_scale
.
raw_ptr
,
bn_bias
.
raw_ptr
,
m_param
.
avg_factor
,
mean
.
raw_ptr
,
variance
.
raw_ptr
,
m_param
.
epsilon
,
batch_mean
.
raw_ptr
,
batch_inv_variance
.
raw_ptr
,
nullptr
,
workspace
.
raw_ptr
,
workspace
.
size
,
reserve
.
raw_ptr
,
reserve
.
layout
.
access_bytes
()));
#else
cudnn_check
(
cudnnBatchNormalizationForwardTraining
(
cudnn_check
(
cudnnBatchNormalizationForwardTraining
(
handle
,
m_tensor_desc
.
bn_mode
,
handle
,
tensor_desc
.
bn_mode
,
&
alpha
,
&
beta
,
&
alpha
,
&
beta
,
tensor_desc
.
xy_desc
.
desc
,
src
.
raw_ptr
,
// xDesc & x
m_tensor_desc
.
xy_desc
.
desc
,
// xDesc
tensor_desc
.
xy_desc
.
desc
,
dst
.
raw_ptr
,
// yDesc & y
src
.
raw_ptr
,
// x
tensor_desc
.
param_desc
.
desc
,
// bnScaleBiasMeanVarDesc
m_tensor_desc
.
xy_desc
.
desc
,
// yDesc
dst
.
raw_ptr
,
// y
m_tensor_desc
.
param_desc
.
desc
,
// bnScaleBiasMeanVarDesc
bn_scale
.
raw_ptr
,
bn_bias
.
raw_ptr
,
m_param
.
avg_factor
,
bn_scale
.
raw_ptr
,
bn_bias
.
raw_ptr
,
m_param
.
avg_factor
,
mean
.
raw_ptr
,
variance
.
raw_ptr
,
m_param
.
epsilon
,
mean
.
raw_ptr
,
variance
.
raw_ptr
,
m_param
.
epsilon
,
batch_mean
.
raw_ptr
,
batch_inv_variance
.
raw_ptr
));
batch_mean
.
raw_ptr
,
batch_inv_variance
.
raw_ptr
));
#endif // CUDNN_VERSION >= 7410
break
;
break
;
case
param
::
BN
::
FwdMode
::
INFERENCE
:
case
param
::
BN
::
FwdMode
::
INFERENCE
:
cudnn_check
(
cudnnBatchNormalizationForwardInference
(
cudnn_check
(
cudnnBatchNormalizationForwardInference
(
handle
,
m_tensor_desc
.
bn_mode
,
handle
,
tensor_desc
.
bn_mode
,
&
alpha
,
&
beta
,
&
alpha
,
&
beta
,
tensor_desc
.
xy_desc
.
desc
,
src
.
raw_ptr
,
m_tensor_desc
.
xy_desc
.
desc
,
src
.
raw_ptr
,
tensor_desc
.
xy_desc
.
desc
,
dst
.
raw_ptr
,
m_tensor_desc
.
xy_desc
.
desc
,
dst
.
raw_ptr
,
tensor_desc
.
param_desc
.
desc
,
bn_scale
.
raw_ptr
,
m_tensor_desc
.
param_desc
.
desc
,
bn_scale
.
raw_ptr
,
bn_bias
.
raw_ptr
,
mean
.
raw_ptr
,
variance
.
raw_ptr
,
bn_bias
.
raw_ptr
,
mean
.
raw_ptr
,
variance
.
raw_ptr
,
m_param
.
epsilon
));
m_param
.
epsilon
));
break
;
break
;
...
@@ -86,30 +154,79 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
...
@@ -86,30 +154,79 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
}
}
}
}
size_t
BNBackwardImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
x
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
#if CUDNN_VERSION >= 7410
auto
handle
=
cudnn_handle
(
this
->
handle
());
BNTensorDescHolder
tensor_desc
(
x
,
m_param
.
param_dim
,
m_param
.
fwd_mode
);
size_t
workspace_size
;
cudnn_check
(
cudnnGetBatchNormalizationBackwardExWorkspaceSize
(
handle
,
tensor_desc
.
bn_mode
,
CUDNN_BATCHNORM_OPS_BN
,
tensor_desc
.
xy_desc
.
desc
,
// xDesc
tensor_desc
.
xy_desc
.
desc
,
// yDesc
tensor_desc
.
xy_desc
.
desc
,
// dyDesc
nullptr
,
// dzDesc
tensor_desc
.
xy_desc
.
desc
,
// dxDesc
tensor_desc
.
param_desc
.
desc
,
// dBnScaleBiasDesc
nullptr
,
// activationDesc
&
workspace_size
));
return
workspace_size
;
#else
return
0
;
#endif // CUDNN_VERSION >= 7410
}
size_t
BNBackwardImpl
::
get_reserve_in_bytes
(
const
TensorLayout
&
src
)
{
BNTensorDescHolder
tensor_desc
(
src
,
m_param
.
param_dim
,
m_param
.
fwd_mode
);
return
batch_normalization
::
get_reserve_size
(
cudnn_handle
(
this
->
handle
()),
tensor_desc
);
}
void
BNBackwardImpl
::
exec
(
_megdnn_tensor_in
x
,
_megdnn_tensor_in
dy
,
void
BNBackwardImpl
::
exec
(
_megdnn_tensor_in
x
,
_megdnn_tensor_in
dy
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_inv_variance
,
_megdnn_tensor_in
saved_batch_inv_variance
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_in
reserve
,
_megdnn_tensor_out
d_bn_scale
,
_megdnn_tensor_out
d_bn_scale
,
_megdnn_tensor_out
d_bn_bias
,
_megdnn_tensor_out
d_bn_bias
,
_megdnn_tensor_out
dx
,
_megdnn_
tensor_out
dx
,
_megdnn_
workspace
workspace
)
{
_megdnn_workspace
workspace
)
{
check_exec
(
x
.
layout
,
dy
.
layout
,
saved_batch_mean
.
layout
,
check_exec
(
x
.
layout
,
dy
.
layout
,
saved_batch_mean
.
layout
,
saved_batch_inv_variance
.
layout
,
bn_scale
.
layout
,
saved_batch_inv_variance
.
layout
,
bn_scale
.
layout
,
d_bn_scale
.
layout
,
d_bn_bias
.
layout
,
dx
.
layout
,
d_bn_scale
.
layout
,
d_bn_bias
.
layout
,
dx
.
layout
,
workspace
.
size
,
workspace
.
size
);
reserve
.
layout
.
access_bytes
()
);
auto
handle
=
cudnn_handle
(
this
->
handle
());
auto
handle
=
cudnn_handle
(
this
->
handle
());
m_tensor_desc
.
setup
(
x
.
layout
,
m_param
.
param_dim
);
BNTensorDescHolder
tensor_desc
(
x
.
layout
,
m_param
.
param_dim
,
m_param
.
fwd_mode
);
float
alpha
=
1.0
,
beta
=
0.0
;
float
alpha
=
1.0
,
beta
=
0.0
;
#if CUDNN_VERSION >= 7410
cudnn_check
(
cudnnBatchNormalizationBackwardEx
(
handle
,
tensor_desc
.
bn_mode
,
CUDNN_BATCHNORM_OPS_BN
,
&
alpha
,
&
beta
,
&
alpha
,
&
beta
,
tensor_desc
.
xy_desc
.
desc
,
x
.
raw_ptr
,
// xDesc & x
nullptr
,
nullptr
,
// yDesc & y
tensor_desc
.
xy_desc
.
desc
,
dy
.
raw_ptr
,
// dyDesc & dy
nullptr
,
nullptr
,
// dzDesc & dz
tensor_desc
.
xy_desc
.
desc
,
dx
.
raw_ptr
,
// dxDesc & dx
tensor_desc
.
param_desc
.
desc
,
bn_scale
.
raw_ptr
,
// bnScale
nullptr
,
// bnBias
d_bn_scale
.
raw_ptr
,
d_bn_bias
.
raw_ptr
,
// dScale, dBias
m_param
.
epsilon
,
saved_batch_mean
.
raw_ptr
,
saved_batch_inv_variance
.
raw_ptr
,
nullptr
,
workspace
.
raw_ptr
,
workspace
.
size
,
reserve
.
raw_ptr
,
reserve
.
layout
.
access_bytes
()));
#else
cudnn_check
(
cudnnBatchNormalizationBackward
(
cudnn_check
(
cudnnBatchNormalizationBackward
(
handle
,
m_tensor_desc
.
bn_mode
,
handle
,
tensor_desc
.
bn_mode
,
&
alpha
,
&
beta
,
&
alpha
,
&
beta
,
&
alpha
,
&
beta
,
&
alpha
,
&
beta
,
tensor_desc
.
xy_desc
.
desc
,
x
.
raw_ptr
,
// xDesc & x
m_tensor_desc
.
xy_desc
.
desc
,
x
.
raw_ptr
,
tensor_desc
.
xy_desc
.
desc
,
dy
.
raw_ptr
,
// dyDesc & dy
m_tensor_desc
.
xy_desc
.
desc
,
dy
.
raw_ptr
,
tensor_desc
.
xy_desc
.
desc
,
dx
.
raw_ptr
,
// dxDesc & dx
m_tensor_desc
.
xy_desc
.
desc
,
dx
.
raw_ptr
,
tensor_desc
.
param_desc
.
desc
,
bn_scale
.
raw_ptr
,
// bnScale
m_tensor_desc
.
param_desc
.
desc
,
bn_scale
.
raw_ptr
,
d_bn_scale
.
raw_ptr
,
d_bn_bias
.
raw_ptr
,
// dScale, dBias
d_bn_scale
.
raw_ptr
,
d_bn_bias
.
raw_ptr
,
m_param
.
epsilon
,
m_param
.
epsilon
,
saved_batch_mean
.
raw_ptr
,
saved_batch_mean
.
raw_ptr
,
saved_batch_inv_variance
.
raw_ptr
));
saved_batch_inv_variance
.
raw_ptr
));
#endif
}
}
}
// namespace cuda
}
// namespace cuda
...
...
dnn/src/cuda/batch_normalization/opr_impl.h
浏览文件 @
b3e54ead
...
@@ -4,9 +4,9 @@
...
@@ -4,9 +4,9 @@
*
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
software
*
software distributed under the License is distributed on an
*
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
*
"AS IS" BASIS, WITHOUT
ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
*/
#pragma once
#pragma once
#include "megdnn/oprs.h"
#include "megdnn/oprs.h"
...
@@ -20,14 +20,20 @@ namespace batch_normalization {
...
@@ -20,14 +20,20 @@ namespace batch_normalization {
struct
BNTensorDescHolder
{
struct
BNTensorDescHolder
{
using
ParamDim
=
param
::
BN
::
ParamDim
;
using
ParamDim
=
param
::
BN
::
ParamDim
;
using
FwdMode
=
param
::
BN
::
FwdMode
;
using
Format
=
param
::
Convolution
::
Format
;
TensorDesc
xy_desc
;
TensorDesc
xy_desc
;
BNParamDesc
param_desc
;
BNParamDesc
param_desc
;
cudnnBatchNormMode_t
bn_mode
;
cudnnBatchNormMode_t
bn_mode
;
void
setup
(
const
TensorLayout
&
x
,
const
ParamDim
&
param_dim
);
BNTensorDescHolder
(
const
TensorLayout
&
x
,
const
ParamDim
&
param_dim
,
const
FwdMode
&
fwd_mode
);
};
};
size_t
get_reserve_size
(
const
cudnnHandle_t
&
handle
,
const
BNTensorDescHolder
&
tensor_desc
);
}
// namespace batch_normalization
}
// namespace batch_normalization
class
BNForwardImpl
final
:
public
BNForward
{
class
BNForwardImpl
final
:
public
BNForward
{
...
@@ -36,19 +42,15 @@ public:
...
@@ -36,19 +42,15 @@ public:
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
bn_scale
,
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_in
bn_bias
,
_megdnn_tensor_out
mean
,
_megdnn_tensor_in
bn_bias
,
_megdnn_tensor_out
mean
,
_megdnn_tensor_out
variance
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
variance
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
dst
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
reserve
,
_megdnn_workspace
workspace
)
override
;
_megdnn_
tensor_out
dst
,
_megdnn_
workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
src
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
const
TensorLayout
&
)
override
;
return
0
;
size_t
get_reserve_in_bytes
(
const
TensorLayout
&
src
)
override
;
}
private:
batch_normalization
::
BNTensorDescHolder
m_tensor_desc
;
};
};
class
BNBackwardImpl
final
:
public
BNBackward
{
class
BNBackwardImpl
final
:
public
BNBackward
{
...
@@ -57,20 +59,16 @@ public:
...
@@ -57,20 +59,16 @@ public:
void
exec
(
_megdnn_tensor_in
x
,
_megdnn_tensor_in
dy
,
void
exec
(
_megdnn_tensor_in
x
,
_megdnn_tensor_in
dy
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_inv_variance
,
_megdnn_tensor_in
saved_batch_inv_variance
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_
out
d_bn_scal
e
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_
in
reserv
e
,
_megdnn_tensor_out
d_bn_
bias
,
_megdnn_tensor_out
dx
,
_megdnn_tensor_out
d_bn_
scale
,
_megdnn_tensor_out
d_bn_bias
,
_megdnn_workspace
workspace
)
override
;
_megdnn_
tensor_out
dx
,
_megdnn_
workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
x
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
const
TensorLayout
&
)
override
;
return
0
;
size_t
get_reserve_in_bytes
(
const
TensorLayout
&
src
)
override
;
}
private:
batch_normalization
::
BNTensorDescHolder
m_tensor_desc
;
};
};
}
// namespace cuda
}
// namespace cuda
...
...
dnn/src/naive/batch_normalization/opr_impl.cpp
浏览文件 @
b3e54ead
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#include "src/naive/batch_normalization/opr_impl.h"
#include "src/naive/batch_normalization/opr_impl.h"
...
@@ -219,7 +220,8 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
...
@@ -219,7 +220,8 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_inout
variance
,
_megdnn_tensor_inout
variance
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
_megdnn_tensor_out
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
check_exec
(
src
.
layout
,
bn_scale
.
layout
,
bn_bias
.
layout
,
mean
.
layout
,
check_exec
(
src
.
layout
,
bn_scale
.
layout
,
bn_bias
.
layout
,
mean
.
layout
,
variance
.
layout
,
batch_mean
.
layout
,
batch_inv_variance
.
layout
,
variance
.
layout
,
batch_mean
.
layout
,
batch_inv_variance
.
layout
,
dst
.
layout
,
workspace
.
size
);
dst
.
layout
,
workspace
.
size
);
...
@@ -263,7 +265,7 @@ WorkspaceBundle BNBackwardImpl::get_workspace_bundle(size_t x_size,
...
@@ -263,7 +265,7 @@ WorkspaceBundle BNBackwardImpl::get_workspace_bundle(size_t x_size,
size_t
BNBackwardImpl
::
get_workspace_in_bytes
(
size_t
BNBackwardImpl
::
get_workspace_in_bytes
(
const
TensorLayout
&
x
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
x
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
bn_scale
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
bn_scale
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
{
auto
x_size
=
x
.
total_nr_elems
(),
param_size
=
bn_scale
.
total_nr_elems
();
auto
x_size
=
x
.
total_nr_elems
(),
param_size
=
bn_scale
.
total_nr_elems
();
return
get_workspace_bundle
(
x_size
,
param_size
).
total_size_in_bytes
();
return
get_workspace_bundle
(
x_size
,
param_size
).
total_size_in_bytes
();
}
}
...
@@ -271,7 +273,7 @@ size_t BNBackwardImpl::get_workspace_in_bytes(
...
@@ -271,7 +273,7 @@ size_t BNBackwardImpl::get_workspace_in_bytes(
void
BNBackwardImpl
::
exec
(
_megdnn_tensor_in
x_in
,
_megdnn_tensor_in
dy_in
,
void
BNBackwardImpl
::
exec
(
_megdnn_tensor_in
x_in
,
_megdnn_tensor_in
dy_in
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_inv_variance
,
_megdnn_tensor_in
saved_batch_inv_variance
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_in
,
_megdnn_tensor_out
d_bn_scale
,
_megdnn_tensor_out
d_bn_scale
,
_megdnn_tensor_out
d_bn_bias
,
_megdnn_tensor_out
d_bn_bias
,
_megdnn_tensor_out
dx_out
,
_megdnn_tensor_out
dx_out
,
...
...
dnn/src/naive/batch_normalization/opr_impl.h
浏览文件 @
b3e54ead
...
@@ -6,7 +6,8 @@
...
@@ -6,7 +6,8 @@
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
*/
#pragma once
#pragma once
#include "megdnn/oprs.h"
#include "megdnn/oprs.h"
...
@@ -21,16 +22,17 @@ public:
...
@@ -21,16 +22,17 @@ public:
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
bn_scale
,
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_in
bn_bias
,
_megdnn_tensor_out
mean
,
_megdnn_tensor_in
bn_bias
,
_megdnn_tensor_out
mean
,
_megdnn_tensor_out
variance
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
variance
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
dst
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
reserve
,
_megdnn_workspace
workspace
)
override
;
_megdnn_
tensor_out
dst
,
_megdnn_
workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
const
TensorLayout
&
)
override
{
return
0
;
return
0
;
}
}
size_t
get_reserve_in_bytes
(
const
TensorLayout
&
)
override
{
return
0
;
}
};
};
class
BNBackwardImpl
final
:
public
BNBackward
{
class
BNBackwardImpl
final
:
public
BNBackward
{
...
@@ -39,15 +41,17 @@ public:
...
@@ -39,15 +41,17 @@ public:
void
exec
(
_megdnn_tensor_in
x
,
_megdnn_tensor_in
dy
,
void
exec
(
_megdnn_tensor_in
x
,
_megdnn_tensor_in
dy
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_inv_variance
,
_megdnn_tensor_in
saved_batch_inv_variance
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_
out
d_bn_scal
e
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_
in
reserv
e
,
_megdnn_tensor_out
d_bn_
bias
,
_megdnn_tensor_out
dx
,
_megdnn_tensor_out
d_bn_
scale
,
_megdnn_tensor_out
d_bn_bias
,
_megdnn_workspace
workspace
)
override
;
_megdnn_
tensor_out
dx
,
_megdnn_
workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
x
,
const
TensorLayout
&
,
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
x
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
bn_scale
,
const
TensorLayout
&
bn_scale
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
;
const
TensorLayout
&
)
override
;
size_t
get_reserve_in_bytes
(
const
TensorLayout
&
)
override
{
return
0
;
}
private:
private:
WorkspaceBundle
get_workspace_bundle
(
size_t
x_size
,
size_t
param_size
,
WorkspaceBundle
get_workspace_bundle
(
size_t
x_size
,
size_t
param_size
,
...
...
dnn/src/rocm/batch_normalization/opr_impl.cpp
浏览文件 @
b3e54ead
...
@@ -49,7 +49,8 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
...
@@ -49,7 +49,8 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
_megdnn_tensor_out
variance
,
_megdnn_tensor_out
variance
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
_megdnn_tensor_out
,
_megdnn_tensor_out
dst
,
_megdnn_workspace
workspace
)
{
check_exec
(
src
.
layout
,
bn_scale
.
layout
,
bn_bias
.
layout
,
mean
.
layout
,
check_exec
(
src
.
layout
,
bn_scale
.
layout
,
bn_bias
.
layout
,
mean
.
layout
,
variance
.
layout
,
batch_mean
.
layout
,
batch_inv_variance
.
layout
,
variance
.
layout
,
batch_mean
.
layout
,
batch_inv_variance
.
layout
,
dst
.
layout
,
workspace
.
size
);
dst
.
layout
,
workspace
.
size
);
...
@@ -88,7 +89,7 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
...
@@ -88,7 +89,7 @@ void BNForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in bn_scale,
void
BNBackwardImpl
::
exec
(
_megdnn_tensor_in
x
,
_megdnn_tensor_in
dy
,
void
BNBackwardImpl
::
exec
(
_megdnn_tensor_in
x
,
_megdnn_tensor_in
dy
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_inv_variance
,
_megdnn_tensor_in
saved_batch_inv_variance
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_in
,
_megdnn_tensor_out
d_bn_scale
,
_megdnn_tensor_out
d_bn_scale
,
_megdnn_tensor_out
d_bn_bias
,
_megdnn_tensor_out
dx
,
_megdnn_tensor_out
d_bn_bias
,
_megdnn_tensor_out
dx
,
_megdnn_workspace
workspace
)
{
_megdnn_workspace
workspace
)
{
...
...
dnn/src/rocm/batch_normalization/opr_impl.h
浏览文件 @
b3e54ead
...
@@ -37,16 +37,17 @@ public:
...
@@ -37,16 +37,17 @@ public:
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
bn_scale
,
void
exec
(
_megdnn_tensor_in
src
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_in
bn_bias
,
_megdnn_tensor_out
mean
,
_megdnn_tensor_in
bn_bias
,
_megdnn_tensor_out
mean
,
_megdnn_tensor_out
variance
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
variance
,
_megdnn_tensor_out
batch_mean
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
dst
,
_megdnn_tensor_out
batch_inv_variance
,
_megdnn_tensor_out
reserve
,
_megdnn_workspace
workspace
)
override
;
_megdnn_
tensor_out
dst
,
_megdnn_
workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
const
TensorLayout
&
)
override
{
return
0
;
return
0
;
}
}
size_t
get_reserve_in_bytes
(
const
TensorLayout
&
)
override
{
return
0
;
}
private:
private:
batch_normalization
::
BNTensorDescHolder
m_tensor_desc
;
batch_normalization
::
BNTensorDescHolder
m_tensor_desc
;
...
@@ -58,17 +59,18 @@ public:
...
@@ -58,17 +59,18 @@ public:
void
exec
(
_megdnn_tensor_in
x
,
_megdnn_tensor_in
dy
,
void
exec
(
_megdnn_tensor_in
x
,
_megdnn_tensor_in
dy
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_mean
,
_megdnn_tensor_in
saved_batch_inv_variance
,
_megdnn_tensor_in
saved_batch_inv_variance
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_
out
d_bn_scal
e
,
_megdnn_tensor_in
bn_scale
,
_megdnn_tensor_
in
reserv
e
,
_megdnn_tensor_out
d_bn_
bias
,
_megdnn_tensor_out
dx
,
_megdnn_tensor_out
d_bn_
scale
,
_megdnn_tensor_out
d_bn_bias
,
_megdnn_workspace
workspace
)
override
;
_megdnn_
tensor_out
dx
,
_megdnn_
workspace
workspace
)
override
;
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
size_t
get_workspace_in_bytes
(
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
,
const
TensorLayout
&
)
override
{
const
TensorLayout
&
)
override
{
return
0
;
return
0
;
}
}
size_t
get_reserve_in_bytes
(
const
TensorLayout
&
)
override
{
return
0
;
}
private:
private:
batch_normalization
::
BNTensorDescHolder
m_tensor_desc
;
batch_normalization
::
BNTensorDescHolder
m_tensor_desc
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录