Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5e813b53
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5e813b53
编写于
12月 01, 2019
作者:
J
Jie Fang
提交者:
gongweibao
12月 01, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
nhwc optimization for batchnorm (#21090)
上级
fce24315
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
461 addition
and
77 deletion
+461
-77
paddle/fluid/framework/grad_op_desc_maker.h
paddle/fluid/framework/grad_op_desc_maker.h
+4
-0
paddle/fluid/imperative/dygraph_grad_maker.h
paddle/fluid/imperative/dygraph_grad_maker.h
+6
-0
paddle/fluid/operators/batch_norm_op.cc
paddle/fluid/operators/batch_norm_op.cc
+41
-19
paddle/fluid/operators/batch_norm_op.cu
paddle/fluid/operators/batch_norm_op.cu
+249
-35
paddle/fluid/operators/batch_norm_op.h
paddle/fluid/operators/batch_norm_op.h
+93
-6
paddle/fluid/platform/dynload/cudnn.cc
paddle/fluid/platform/dynload/cudnn.cc
+4
-0
paddle/fluid/platform/dynload/cudnn.h
paddle/fluid/platform/dynload/cudnn.h
+9
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+24
-10
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
+31
-7
未找到文件。
paddle/fluid/framework/grad_op_desc_maker.h
浏览文件 @
5e813b53
...
@@ -141,6 +141,10 @@ class GradOpDescMakerBase {
...
@@ -141,6 +141,10 @@ class GradOpDescMakerBase {
return
(
fwd_op_
.
Inputs
().
count
(
name
)
>
0
);
return
(
fwd_op_
.
Inputs
().
count
(
name
)
>
0
);
}
}
bool
HasOutput
(
const
std
::
string
&
name
)
const
{
return
(
fwd_op_
.
Outputs
().
count
(
name
)
>
0
);
}
private:
private:
const
OpDesc
&
fwd_op_
;
const
OpDesc
&
fwd_op_
;
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set_
;
const
std
::
unordered_set
<
std
::
string
>&
no_grad_set_
;
...
...
paddle/fluid/imperative/dygraph_grad_maker.h
浏览文件 @
5e813b53
...
@@ -107,6 +107,12 @@ class GradOpBaseMakerBase {
...
@@ -107,6 +107,12 @@ class GradOpBaseMakerBase {
return
it
!=
var_base_map_in_
.
end
();
return
it
!=
var_base_map_in_
.
end
();
}
}
bool
HasOutput
(
const
std
::
string
name
)
const
{
auto
it
=
var_base_map_out_
.
find
(
name
);
return
it
!=
var_base_map_out_
.
end
();
}
private:
private:
std
::
vector
<
std
::
shared_ptr
<
VarBase
>>
GetVarBaseList
(
const
std
::
string
&
name
,
std
::
vector
<
std
::
shared_ptr
<
VarBase
>>
GetVarBaseList
(
const
std
::
string
&
name
,
bool
is_grad
,
bool
is_grad
,
...
...
paddle/fluid/operators/batch_norm_op.cc
浏览文件 @
5e813b53
...
@@ -25,27 +25,42 @@ namespace paddle {
...
@@ -25,27 +25,42 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
void
BatchNormOp
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
void
BatchNormOp
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of ConvOp should not be null."
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Scale"
),
platform
::
errors
::
InvalidArgument
(
"Input(Scale) of ConvOp should not be null."
);
"Input(X) of BatchNormOp should not be null."
));
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Bias"
),
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Scale"
),
true
,
"Input(Bias) of ConvOp should not be null."
);
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Mean"
),
"Input(Scale) of BatchNormOp should not be null."
));
"Input(Mean) of ConvOp should not be null."
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Bias"
),
true
,
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Variance"
),
platform
::
errors
::
InvalidArgument
(
"Input(Variance) of ConvOp should not be null."
);
"Input(Bias) of BatchNormOp should not be null."
));
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Mean"
),
true
,
"Output(Y) of ConvOp should not be null."
);
platform
::
errors
::
InvalidArgument
(
"Input(Mean) of BatchNormOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Variance"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Variance) of BatchNormOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Y"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Output(Y) of BatchNormOp should not be null."
));
bool
is_test
=
ctx
->
Attrs
().
Get
<
bool
>
(
"is_test"
);
bool
is_test
=
ctx
->
Attrs
().
Get
<
bool
>
(
"is_test"
);
if
(
!
is_test
)
{
if
(
!
is_test
)
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"MeanOut"
),
PADDLE_ENFORCE_EQ
(
"Output(MeanOut) of ConvOp should not be null."
);
ctx
->
HasOutput
(
"MeanOut"
),
true
,
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"VarianceOut"
),
platform
::
errors
::
InvalidArgument
(
"Output(VarianceOut) of ConvOp should not be null."
);
"Output(MeanOut) of BatchNormOp should not be null."
));
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"SavedMean"
),
PADDLE_ENFORCE_EQ
(
"Output(SavedMean) of ConvOp should not be null."
);
ctx
->
HasOutput
(
"VarianceOut"
),
true
,
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"SavedVariance"
),
platform
::
errors
::
InvalidArgument
(
"Output(SavedVariance) of ConvOp should not be null."
);
"Output(VarianceOut) of BatchNormOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"SavedMean"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Output(SavedMean) of BatchNormOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"SavedVariance"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Output(SavedVariance) of BatchNormOp should not be null."
));
}
}
// make sure Mean/MeanOut and Variance/VarianceOut share memory in Python
// make sure Mean/MeanOut and Variance/VarianceOut share memory in Python
...
@@ -200,6 +215,10 @@ void BatchNormOpMaker::Make() {
...
@@ -200,6 +215,10 @@ void BatchNormOpMaker::Make() {
"Variance of the current mini batch, "
"Variance of the current mini batch, "
"will apply to output when training"
)
"will apply to output when training"
)
.
AsIntermediate
();
.
AsIntermediate
();
AddOutput
(
"ReserveSpace"
,
"Reserve GPU space for triggering the new semi-persistent "
"NHWC kernel"
)
.
AsDispensable
();
AddAttr
<
bool
>
(
"use_mkldnn"
,
AddAttr
<
bool
>
(
"use_mkldnn"
,
"(bool, default false) Only used in mkldnn kernel"
)
"(bool, default false) Only used in mkldnn kernel"
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
...
@@ -643,6 +662,9 @@ std::unique_ptr<T> BatchNormGradMaker<T>::Apply() const {
...
@@ -643,6 +662,9 @@ std::unique_ptr<T> BatchNormGradMaker<T>::Apply() const {
op
->
SetInput
(
"Bias"
,
this
->
Input
(
"Bias"
));
op
->
SetInput
(
"Bias"
,
this
->
Input
(
"Bias"
));
op
->
SetInput
(
"SavedMean"
,
this
->
Output
(
"SavedMean"
));
op
->
SetInput
(
"SavedMean"
,
this
->
Output
(
"SavedMean"
));
op
->
SetInput
(
"SavedVariance"
,
this
->
Output
(
"SavedVariance"
));
op
->
SetInput
(
"SavedVariance"
,
this
->
Output
(
"SavedVariance"
));
if
(
this
->
HasOutput
(
"ReserveSpace"
))
{
op
->
SetInput
(
"ReserveSpace"
,
this
->
Output
(
"ReserveSpace"
));
}
// used when setting use_global_stats True during training
// used when setting use_global_stats True during training
if
(
boost
::
get
<
bool
>
(
this
->
GetAttr
(
"use_global_stats"
)))
{
if
(
boost
::
get
<
bool
>
(
this
->
GetAttr
(
"use_global_stats"
)))
{
...
...
paddle/fluid/operators/batch_norm_op.cu
浏览文件 @
5e813b53
...
@@ -56,12 +56,39 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
...
@@ -56,12 +56,39 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
const
auto
&
x_dims
=
x
->
dims
();
const
auto
&
x_dims
=
x
->
dims
();
PADDLE_ENFORCE
(
x_dims
.
size
()
>=
2
&&
x_dims
.
size
()
<=
5
,
PADDLE_ENFORCE
(
x_dims
.
size
()
>=
2
&&
x_dims
.
size
()
<=
5
,
"The Input dim size should be between 2 and 5"
);
"The Input dim size should be between 2 and 5"
);
int
N
,
C
,
H
,
W
,
D
;
ExtractNCWHD
(
x_dims
,
data_layout
,
&
N
,
&
C
,
&
H
,
&
W
,
&
D
);
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
N
,
C
,
H
,
W
,
D
;
ExtractNCWHD
(
x_dims
,
data_layout
,
&
N
,
&
C
,
&
H
,
&
W
,
&
D
);
auto
dtype
=
platform
::
CudnnDataType
<
T
>::
type
;
const
bool
fast_nhwc_batch_norm
=
is_test
||
(
dtype
==
CUDNN_DATA_HALF
&&
FLAGS_cudnn_batchnorm_spatial_persistent
);
auto
compute_format
=
fast_nhwc_batch_norm
&&
data_layout
==
DataLayout
::
kNHWC
?
DataLayout
::
kNHWC
:
DataLayout
::
kNCHW
;
Tensor
transformed_x
(
x
->
type
());
Tensor
transformed_y
(
y
->
type
());
if
(
data_layout
==
DataLayout
::
kNHWC
&&
compute_format
==
DataLayout
::
kNCHW
&&
x_dims
.
size
()
>
2
)
{
VLOG
(
3
)
<<
"Transform input tensor from NHWC to NCHW."
;
ResizeToChannelFirst
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
x
,
&
transformed_x
);
TransToChannelFirst
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
x
,
&
transformed_x
);
ResizeToChannelFirst
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
y
,
&
transformed_y
);
}
else
{
transformed_x
.
ShareDataWith
(
*
x
);
transformed_y
.
ShareDataWith
(
*
y
);
}
// ------------------- cudnn descriptors ---------------------
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t
data_desc_
;
cudnnTensorDescriptor_t
data_desc_
;
cudnnTensorDescriptor_t
bn_param_desc_
;
cudnnTensorDescriptor_t
bn_param_desc_
;
...
@@ -90,7 +117,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
...
@@ -90,7 +117,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
VLOG
(
3
)
<<
"Setting descriptors."
;
VLOG
(
3
)
<<
"Setting descriptors."
;
std
::
vector
<
int
>
dims
;
std
::
vector
<
int
>
dims
;
std
::
vector
<
int
>
strides
;
std
::
vector
<
int
>
strides
;
if
(
data_layou
t
==
DataLayout
::
kNCHW
)
{
if
(
compute_forma
t
==
DataLayout
::
kNCHW
)
{
dims
=
{
N
,
C
,
H
,
W
,
D
};
dims
=
{
N
,
C
,
H
,
W
,
D
};
strides
=
{
C
*
H
*
W
*
D
,
H
*
W
*
D
,
W
*
D
,
D
,
1
};
strides
=
{
C
*
H
*
W
*
D
,
H
*
W
*
D
,
W
*
D
,
D
,
1
};
}
else
{
}
else
{
...
@@ -126,8 +153,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
...
@@ -126,8 +153,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
handle
,
handle
,
// Note: PERSISTENT not implemented for inference
// Note: PERSISTENT not implemented for inference
CUDNN_BATCHNORM_SPATIAL
,
CudnnDataType
<
T
>::
kOne
(),
CUDNN_BATCHNORM_SPATIAL
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
x
->
template
data
<
T
>(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
data_desc_
,
y
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
transformed_x
.
template
data
<
T
>(),
data_desc_
,
transformed_y
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
bn_param_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
bn_param_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
bias
->
template
data
<
BatchNormParamType
<
T
>
>
(),
bias
->
template
data
<
BatchNormParamType
<
T
>
>
(),
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
est_mean
->
template
data
<
BatchNormParamType
<
T
>
>
(),
...
@@ -167,23 +195,102 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
...
@@ -167,23 +195,102 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
}
else
{
}
else
{
double
this_factor
=
1.
-
momentum
;
double
this_factor
=
1.
-
momentum
;
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnBatchNormalizationForwardTraining
(
bool
called
=
false
;
handle
,
mode_
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
#if CUDNN_VERSION_MIN(7, 4, 1)
data_desc_
,
x
->
template
data
<
T
>(),
data_desc_
,
if
(
compute_format
==
DataLayout
::
kNHWC
)
{
y
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
bn_param_desc_
,
called
=
true
;
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
size_t
workspace_size
=
0
;
size_t
reserve_space_size
=
0
;
void
*
reserve_space_ptr
=
nullptr
;
void
*
workspace_ptr
=
nullptr
;
Tensor
workspace_tensor
;
// Create reserve space and workspace for batch norm.
// Create tensor for each batchnorm op, it will be used in the
// backward. Thus this tensor shouldn't be temp.
auto
*
reserve_space
=
ctx
.
Output
<
Tensor
>
(
"ReserveSpace"
);
PADDLE_ENFORCE_NOT_NULL
(
reserve_space
,
platform
::
errors
::
NotFound
(
"The argument ReserveSpace of batch_norm op is not found."
));
// --------------- cudnn batchnorm workspace ---------------
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize
(
/*handle=*/
handle
,
/*mode=*/
mode_
,
/*bnIps=*/
CUDNN_BATCHNORM_OPS_BN
,
/*xDesc=*/
data_desc_
,
/*zDesc=*/
nullptr
,
/*yDesc=*/
data_desc_
,
/*bnScaleBiasMeanVarDesc=*/
bn_param_desc_
,
/*activationDesc=*/
nullptr
,
/*sizeInBytes=*/
&
workspace_size
));
// -------------- cudnn batchnorm reserve space --------------
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetBatchNormalizationTrainingExReserveSpaceSize
(
/*handle=*/
handle
,
/*mode=*/
mode_
,
/*bnOps=*/
CUDNN_BATCHNORM_OPS_BN
,
/*activationDesc=*/
nullptr
,
/*xDesc=*/
data_desc_
,
/*sizeInBytes=*/
&
reserve_space_size
));
reserve_space_ptr
=
reserve_space
->
mutable_data
(
ctx
.
GetPlace
(),
transformed_x
.
type
(),
reserve_space_size
);
workspace_ptr
=
workspace_tensor
.
mutable_data
(
ctx
.
GetPlace
(),
transformed_x
.
type
(),
workspace_size
);
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnBatchNormalizationForwardTrainingEx
(
handle
,
mode_
,
CUDNN_BATCHNORM_OPS_BN
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
transformed_x
.
template
data
<
T
>(),
nullptr
,
nullptr
,
data_desc_
,
transformed_y
.
template
data
<
T
>(),
bn_param_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
bias
->
template
data
<
BatchNormParamType
<
T
>
>
(),
this_factor
,
mean_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
variance_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
epsilon
,
saved_mean
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
saved_variance
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
nullptr
,
workspace_ptr
,
workspace_size
,
reserve_space_ptr
,
reserve_space_size
));
}
#endif
if
(
!
called
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnBatchNormalizationForwardTraining
(
handle
,
mode_
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
transformed_x
.
template
data
<
T
>(),
data_desc_
,
transformed_y
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
bn_param_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
bias
->
template
data
<
BatchNormParamType
<
T
>
>
(),
this_factor
,
bias
->
template
data
<
BatchNormParamType
<
T
>
>
(),
this_factor
,
mean_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
mean_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
ctx
.
GetPlace
()),
variance_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
variance_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
ctx
.
GetPlace
()),
epsilon
,
saved_mean
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
epsilon
,
saved_mean
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
ctx
.
GetPlace
()),
saved_variance
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
saved_variance
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
())));
ctx
.
GetPlace
())));
}
}
}
}
}
if
(
data_layout
==
DataLayout
::
kNHWC
&&
compute_format
==
DataLayout
::
kNCHW
&&
x_dims
.
size
()
>
2
)
{
VLOG
(
3
)
<<
"Transform batchnorm output from NCHW to NHWC"
;
TransToChannelLast
<
paddle
::
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
&
transformed_y
,
y
);
}
// clean when exit.
// clean when exit.
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
data_desc_
));
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
data_desc_
));
CUDNN_ENFORCE
(
CUDNN_ENFORCE
(
...
@@ -337,9 +444,41 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
...
@@ -337,9 +444,41 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
PADDLE_ENFORCE_EQ
(
scale
->
dims
().
size
(),
1UL
);
PADDLE_ENFORCE_EQ
(
scale
->
dims
().
size
(),
1UL
);
PADDLE_ENFORCE_EQ
(
scale
->
dims
()[
0
],
C
);
PADDLE_ENFORCE_EQ
(
scale
->
dims
()[
0
],
C
);
auto
dtype
=
platform
::
CudnnDataType
<
T
>::
type
;
const
auto
*
reserve_space
=
ctx
.
Input
<
Tensor
>
(
"ReserveSpace"
);
const
bool
fast_nhwc_batch_norm
=
dtype
==
CUDNN_DATA_HALF
&&
FLAGS_cudnn_batchnorm_spatial_persistent
&&
reserve_space
!=
nullptr
;
auto
compute_format
=
fast_nhwc_batch_norm
&&
data_layout
==
DataLayout
::
kNHWC
?
DataLayout
::
kNHWC
:
DataLayout
::
kNCHW
;
Tensor
transformed_x
(
x
->
type
());
Tensor
transformed_d_y
(
d_y
->
type
());
Tensor
transformed_d_x
(
d_x
->
type
());
if
(
data_layout
==
DataLayout
::
kNHWC
&&
compute_format
==
DataLayout
::
kNCHW
)
{
VLOG
(
3
)
<<
"Transform input tensor from NHWC to NCHW."
;
ResizeToChannelFirst
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
x
,
&
transformed_x
);
TransToChannelFirst
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
x
,
&
transformed_x
);
ResizeToChannelFirst
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
d_y
,
&
transformed_d_y
);
TransToChannelFirst
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
d_y
,
&
transformed_d_y
);
ResizeToChannelFirst
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
d_x
,
&
transformed_d_x
);
}
else
{
transformed_x
.
ShareDataWith
(
*
x
);
transformed_d_y
.
ShareDataWith
(
*
d_y
);
transformed_d_x
.
ShareDataWith
(
*
d_x
);
}
std
::
vector
<
int
>
dims
;
std
::
vector
<
int
>
dims
;
std
::
vector
<
int
>
strides
;
std
::
vector
<
int
>
strides
;
if
(
data_layou
t
==
DataLayout
::
kNCHW
)
{
if
(
compute_forma
t
==
DataLayout
::
kNCHW
)
{
dims
=
{
N
,
C
,
H
,
W
,
D
};
dims
=
{
N
,
C
,
H
,
W
,
D
};
strides
=
{
C
*
H
*
W
*
D
,
H
*
W
*
D
,
W
*
D
,
D
,
1
};
strides
=
{
C
*
H
*
W
*
D
,
H
*
W
*
D
,
W
*
D
,
D
,
1
};
}
else
{
}
else
{
...
@@ -348,7 +487,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
...
@@ -348,7 +487,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
}
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
const
int
num
=
x
->
numel
();
const
int
num
=
transformed_x
.
numel
();
const
int
block
=
512
;
const
int
block
=
512
;
int
max_threads
=
dev_ctx
.
GetMaxPhysicalThreadCount
();
int
max_threads
=
dev_ctx
.
GetMaxPhysicalThreadCount
();
const
int
max_blocks
=
std
::
max
(
max_threads
/
block
,
1
);
const
int
max_blocks
=
std
::
max
(
max_threads
/
block
,
1
);
...
@@ -404,20 +543,95 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
...
@@ -404,20 +543,95 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
saved_var
->
template
data
<
BatchNormParamType
<
T
>
>
();
saved_var
->
template
data
<
BatchNormParamType
<
T
>
>
();
if
(
d_scale
&&
d_bias
)
{
if
(
d_scale
&&
d_bias
)
{
bool
called
=
false
;
#if CUDNN_VERSION_MIN(7, 4, 1)
if
(
compute_format
==
DataLayout
::
kNHWC
)
{
called
=
true
;
size_t
workspace_size
=
0
;
void
*
workspace_ptr
=
nullptr
;
Tensor
workspace_tensor
;
auto
reserve_space_size
=
reserve_space
->
memory_size
();
// --------------- cudnn batchnorm workspace ---------------
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnGetBatchNormalizationBackwardExWorkspaceSize
(
/*handle=*/
dev_ctx
.
cudnn_handle
(),
/*mode=*/
mode_
,
/*bnIps=*/
CUDNN_BATCHNORM_OPS_BN
,
/*xDesc=*/
data_desc_
,
/*yDesc=*/
data_desc_
,
/*dyDesc=*/
data_desc_
,
/*dzDesc=*/
nullptr
,
/*dxDesc=*/
data_desc_
,
/*bnScaleBiasMeanVarDesc=*/
bn_param_desc_
,
/*activationDesc=*/
nullptr
,
/*sizeInBytes=*/
&
workspace_size
));
workspace_ptr
=
workspace_tensor
.
mutable_data
(
ctx
.
GetPlace
(),
transformed_x
.
type
(),
workspace_size
);
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnBatchNormalizationBackwardEx
(
/*handle=*/
dev_ctx
.
cudnn_handle
(),
/*mode=*/
mode_
,
/*bnOps=*/
CUDNN_BATCHNORM_OPS_BN
,
/*alphaDataDiff=*/
CudnnDataType
<
T
>::
kOne
(),
/*betaDataDiff=*/
CudnnDataType
<
T
>::
kZero
(),
/*alphaParamDiff=*/
CudnnDataType
<
T
>::
kOne
(),
/*betaParamDiff=*/
CudnnDataType
<
T
>::
kZero
(),
/*xDesc=*/
data_desc_
,
/*xData=*/
transformed_x
.
template
data
<
T
>(),
/*yDesc=*/
nullptr
,
/*yData=*/
nullptr
,
/*dyDesc=*/
data_desc_
,
/*dyData=*/
transformed_d_y
.
template
data
<
T
>(),
/*dzDesc=*/
nullptr
,
/*dzData=*/
nullptr
,
/*dxDesc=*/
data_desc_
,
/*dxData=*/
transformed_d_x
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
/*dBnScaleBiasDesc=*/
bn_param_desc_
,
/*bnScaleData=*/
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
/*bnBiasData=*/
nullptr
,
/*dBnScaleData=*/
d_scale
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
/*dBnBiasData=*/
d_bias
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
/*epsilon=*/
epsilon
,
/*savedMean=*/
saved_mean_data
,
/*savedInvVariance=*/
saved_var_data
,
/*activationDesc=*/
nullptr
,
/*workspace=*/
workspace_ptr
,
/*workSpaceSizeInBytes=*/
workspace_size
,
/*reserveSpace=*/
const_cast
<
T
*>
(
reserve_space
->
template
data
<
T
>()),
/*reserveSpaceSizeInBytes=*/
reserve_space_size
));
}
#endif
if
(
!
called
)
{
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnBatchNormalizationBackward
(
CUDNN_ENFORCE
(
platform
::
dynload
::
cudnnBatchNormalizationBackward
(
dev_ctx
.
cudnn_handle
(),
mode_
,
CudnnDataType
<
T
>::
kOne
(),
dev_ctx
.
cudnn_handle
(),
mode_
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
x
->
template
data
<
T
>(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
data_desc_
,
d_y
->
template
data
<
T
>(),
data_desc_
,
transformed_x
.
template
data
<
T
>(),
data_desc_
,
d_x
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
bn_param_desc_
,
transformed_d_y
.
template
data
<
T
>(),
data_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
transformed_d_x
.
template
mutable_data
<
T
>(
ctx
.
GetPlace
()),
bn_param_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
d_scale
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
d_scale
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
ctx
.
GetPlace
()),
d_bias
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
d_bias
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
ctx
.
GetPlace
()),
epsilon
,
saved_mean_data
,
saved_var_data
));
epsilon
,
saved_mean_data
,
saved_var_data
));
}
if
(
data_layout
==
DataLayout
::
kNHWC
&&
compute_format
==
DataLayout
::
kNCHW
)
{
VLOG
(
3
)
<<
"Transform batchnorm output from NCHW to NHWC"
;
TransToChannelLast
<
paddle
::
platform
::
CUDADeviceContext
,
T
>
(
ctx
,
&
transformed_d_x
,
d_x
);
}
}
else
{
}
else
{
if
(
data_layout
==
framework
::
DataLayout
::
kNCHW
)
{
if
(
compute_format
==
DataLayout
::
kNCHW
)
{
if
(
d_x
)
{
if
(
d_x
)
{
BNBackwardData
<
T
,
block
,
framework
::
DataLayout
::
kNCHW
><<<
BNBackwardData
<
T
,
block
,
framework
::
DataLayout
::
kNCHW
><<<
grid2
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
grid2
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
...
@@ -450,7 +664,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
...
@@ -450,7 +664,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
const
auto
*
running_var_data
=
const
auto
*
running_var_data
=
running_var
->
template
data
<
BatchNormParamType
<
T
>
>
();
running_var
->
template
data
<
BatchNormParamType
<
T
>
>
();
if
(
data_layout
==
framework
::
DataLayout
::
kNCHW
)
{
if
(
compute_format
==
DataLayout
::
kNCHW
)
{
if
(
d_x
)
{
if
(
d_x
)
{
KeBNBackwardData
<
T
,
framework
::
DataLayout
::
kNCHW
><<<
KeBNBackwardData
<
T
,
framework
::
DataLayout
::
kNCHW
><<<
grid1
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
grid1
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
...
...
paddle/fluid/operators/batch_norm_op.h
浏览文件 @
5e813b53
...
@@ -16,8 +16,10 @@ limitations under the License. */
...
@@ -16,8 +16,10 @@ limitations under the License. */
#include <memory>
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/norm_utils.h"
#include "paddle/fluid/operators/norm_utils.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -39,24 +41,109 @@ template <typename T>
...
@@ -39,24 +41,109 @@ template <typename T>
using
ConstEigenVectorArrayMap
=
using
ConstEigenVectorArrayMap
=
Eigen
::
Map
<
const
Eigen
::
Array
<
T
,
Eigen
::
Dynamic
,
1
>>
;
Eigen
::
Map
<
const
Eigen
::
Array
<
T
,
Eigen
::
Dynamic
,
1
>>
;
template
<
typename
DeviceContext
,
typename
T
>
inline
void
ResizeToChannelFirst
(
const
framework
::
ExecutionContext
&
context
,
const
Tensor
*
input
,
Tensor
*
transformed_input
)
{
int
dim
=
input
->
dims
().
size
()
-
2
;
if
(
dim
==
3
)
{
// input
transformed_input
->
Resize
(
input
->
dims
());
auto
in_dims_vec
=
framework
::
vectorize
(
input
->
dims
());
in_dims_vec
[
1
]
=
input
->
dims
()[
4
];
in_dims_vec
[
2
]
=
input
->
dims
()[
1
];
in_dims_vec
[
3
]
=
input
->
dims
()[
2
];
in_dims_vec
[
4
]
=
input
->
dims
()[
3
];
transformed_input
->
Resize
(
framework
::
make_ddim
(
in_dims_vec
));
transformed_input
->
mutable_data
<
T
>
(
context
.
GetPlace
());
}
else
if
(
dim
==
2
)
{
// input
transformed_input
->
Resize
(
input
->
dims
());
auto
in_dims_vec
=
framework
::
vectorize
(
input
->
dims
());
in_dims_vec
[
1
]
=
input
->
dims
()[
3
];
in_dims_vec
[
2
]
=
input
->
dims
()[
1
];
in_dims_vec
[
3
]
=
input
->
dims
()[
2
];
transformed_input
->
Resize
(
framework
::
make_ddim
(
in_dims_vec
));
transformed_input
->
mutable_data
<
T
>
(
context
.
GetPlace
());
}
else
if
(
dim
==
1
)
{
transformed_input
->
Resize
(
input
->
dims
());
auto
in_dims_vec
=
framework
::
vectorize
(
input
->
dims
());
in_dims_vec
[
1
]
=
input
->
dims
()[
2
];
in_dims_vec
[
2
]
=
input
->
dims
()[
1
];
transformed_input
->
Resize
(
framework
::
make_ddim
(
in_dims_vec
));
transformed_input
->
mutable_data
<
T
>
(
context
.
GetPlace
());
}
}
template
<
typename
DeviceContext
,
typename
T
>
inline
void
TransToChannelFirst
(
const
framework
::
ExecutionContext
&
context
,
const
Tensor
*
input
,
Tensor
*
transformed_input
)
{
int
dim
=
input
->
dims
().
size
()
-
2
;
if
(
dim
==
3
)
{
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
std
::
vector
<
int
>
axis
{
0
,
4
,
1
,
2
,
3
};
math
::
Transpose
<
DeviceContext
,
T
,
5
>
trans5
;
trans5
(
dev_ctx
,
*
input
,
transformed_input
,
axis
);
}
else
if
(
dim
==
2
)
{
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
std
::
vector
<
int
>
axis
{
0
,
3
,
1
,
2
};
math
::
Transpose
<
DeviceContext
,
T
,
4
>
trans4
;
trans4
(
dev_ctx
,
*
input
,
transformed_input
,
axis
);
}
else
if
(
dim
==
1
)
{
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
std
::
vector
<
int
>
axis
{
0
,
2
,
1
};
math
::
Transpose
<
DeviceContext
,
T
,
3
>
trans3
;
trans3
(
dev_ctx
,
*
input
,
transformed_input
,
axis
);
}
}
template
<
typename
DeviceContext
,
typename
T
>
inline
void
TransToChannelLast
(
const
framework
::
ExecutionContext
&
context
,
const
Tensor
*
input
,
Tensor
*
transformed_input
)
{
int
dim
=
input
->
dims
().
size
()
-
2
;
if
(
dim
==
3
)
{
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
std
::
vector
<
int
>
axis
{
0
,
2
,
3
,
4
,
1
};
math
::
Transpose
<
DeviceContext
,
T
,
5
>
trans5
;
trans5
(
dev_ctx
,
*
input
,
transformed_input
,
axis
);
}
else
if
(
dim
==
2
)
{
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
std
::
vector
<
int
>
axis
{
0
,
2
,
3
,
1
};
math
::
Transpose
<
DeviceContext
,
T
,
4
>
trans4
;
trans4
(
dev_ctx
,
*
input
,
transformed_input
,
axis
);
}
else
if
(
dim
==
1
)
{
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
std
::
vector
<
int
>
axis
{
0
,
2
,
1
};
math
::
Transpose
<
DeviceContext
,
T
,
3
>
trans3
;
trans3
(
dev_ctx
,
*
input
,
transformed_input
,
axis
);
}
}
class
BatchNormOp
:
public
framework
::
OperatorWithKernel
{
class
BatchNormOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
};
};
class
BatchNormGradOp
:
public
framework
::
OperatorWithKernel
{
class
BatchNormGradOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
protected:
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
};
};
class
BatchNormOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
BatchNormOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
@@ -85,13 +172,13 @@ class BatchNormOpInferVarType
...
@@ -85,13 +172,13 @@ class BatchNormOpInferVarType
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
BatchNormKernel
:
public
framework
::
OpKernel
<
T
>
{
class
BatchNormKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
};
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
BatchNormGradKernel
:
public
framework
::
OpKernel
<
T
>
{
class
BatchNormGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/platform/dynload/cudnn.cc
浏览文件 @
5e813b53
...
@@ -46,6 +46,10 @@ CUDNN_DNN_ROUTINE_EACH_R6(DEFINE_WRAP);
...
@@ -46,6 +46,10 @@ CUDNN_DNN_ROUTINE_EACH_R6(DEFINE_WRAP);
CUDNN_DNN_ROUTINE_EACH_R7
(
DEFINE_WRAP
);
CUDNN_DNN_ROUTINE_EACH_R7
(
DEFINE_WRAP
);
#endif
#endif
#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R7
CUDNN_DNN_ROUTINE_EACH_AFTER_R7
(
DEFINE_WRAP
);
#endif
#ifdef PADDLE_USE_DSO
#ifdef PADDLE_USE_DSO
bool
HasCUDNN
()
{
bool
HasCUDNN
()
{
std
::
call_once
(
cudnn_dso_flag
,
std
::
call_once
(
cudnn_dso_flag
,
...
...
paddle/fluid/platform/dynload/cudnn.h
浏览文件 @
5e813b53
...
@@ -189,6 +189,15 @@ CUDNN_DNN_ROUTINE_EACH_R6(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
...
@@ -189,6 +189,15 @@ CUDNN_DNN_ROUTINE_EACH_R6(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
CUDNN_DNN_ROUTINE_EACH_R7
(
DECLARE_DYNAMIC_LOAD_CUDNN_WRAP
)
CUDNN_DNN_ROUTINE_EACH_R7
(
DECLARE_DYNAMIC_LOAD_CUDNN_WRAP
)
#endif
#endif
#if CUDNN_VERSION >= 7401
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R7(__macro) \
__macro(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize); \
__macro(cudnnBatchNormalizationForwardTrainingEx); \
__macro(cudnnGetBatchNormalizationBackwardExWorkspaceSize); \
__macro(cudnnBatchNormalizationBackwardEx); \
__macro(cudnnGetBatchNormalizationTrainingExReserveSpaceSize);
CUDNN_DNN_ROUTINE_EACH_AFTER_R7
(
DECLARE_DYNAMIC_LOAD_CUDNN_WRAP
)
#endif
}
// namespace dynload
}
// namespace dynload
}
// namespace platform
}
// namespace platform
}
// namespace paddle
}
// namespace paddle
python/paddle/fluid/layers/nn.py
浏览文件 @
5e813b53
...
@@ -2523,6 +2523,13 @@ def batch_norm(input,
...
@@ -2523,6 +2523,13 @@ def batch_norm(input,
check_type_and_dtype(input, 'input', Variable,
check_type_and_dtype(input, 'input', Variable,
['float16', 'float32', 'float64'], 'batch_norm')
['float16', 'float32', 'float64'], 'batch_norm')
dtype = helper.input_dtype()
dtype = helper.input_dtype()
has_reserve_space = False
if data_layout == 'NHWC':
flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent')
if flag is not None and flag.lower() in ['true', '1']:
has_reserve_space = True
# use fp32 for bn parameter
# use fp32 for bn parameter
if dtype == core.VarDesc.VarType.FP16:
if dtype == core.VarDesc.VarType.FP16:
dtype = core.VarDesc.VarType.FP32
dtype = core.VarDesc.VarType.FP32
...
@@ -2577,6 +2584,11 @@ def batch_norm(input,
...
@@ -2577,6 +2584,11 @@ def batch_norm(input,
saved_variance = helper.create_variable_for_type_inference(
saved_variance = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
dtype=dtype, stop_gradient=True)
reserve_space = None
if has_reserve_space:
reserve_space = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.FP16, stop_gradient=True)
batch_norm_out = input if in_place else helper.create_variable_for_type_inference(
batch_norm_out = input if in_place else helper.create_variable_for_type_inference(
dtype)
dtype)
...
@@ -2599,17 +2611,19 @@ def batch_norm(input,
...
@@ -2599,17 +2611,19 @@ def batch_norm(input,
inputs['MomemtumTensor'] = momentum
inputs['MomemtumTensor'] = momentum
else:
else:
attrs['momentum'] = momentum
attrs['momentum'] = momentum
helper.append_op(
type="batch_norm",
outputs = {
inputs=inputs,
outputs={
"Y": batch_norm_out,
"Y": batch_norm_out,
"MeanOut": mean_out,
"MeanOut": mean_out,
"VarianceOut": variance_out,
"VarianceOut": variance_out,
"SavedMean": saved_mean,
"SavedMean": saved_mean,
"SavedVariance": saved_variance
"SavedVariance": saved_variance
},
}
attrs=attrs)
if reserve_space is not None:
outputs["ReserveSpace"] = reserve_space
helper.append_op(
type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs)
return helper.append_activation(batch_norm_out)
return helper.append_activation(batch_norm_out)
...
...
python/paddle/fluid/tests/unittests/test_batch_norm_op.py
浏览文件 @
5e813b53
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
os
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid.core
as
core
import
paddle.fluid.core
as
core
...
@@ -413,16 +414,28 @@ class TestBatchNormOpTraining(unittest.TestCase):
...
@@ -413,16 +414,28 @@ class TestBatchNormOpTraining(unittest.TestCase):
inputs
[
'MomentumTensor'
]
=
block
.
var
(
'momentum_var'
)
inputs
[
'MomentumTensor'
]
=
block
.
var
(
'momentum_var'
)
else
:
else
:
attrs
[
'momentum'
]
=
momentum
attrs
[
'momentum'
]
=
momentum
bn_op
=
block
.
append_op
(
type
=
"batch_norm"
,
outputs
=
{
inputs
=
inputs
,
outputs
=
{
"Y"
:
block
.
var
(
'y'
),
"Y"
:
block
.
var
(
'y'
),
"MeanOut"
:
block
.
var
(
'mean'
),
# share memory
"MeanOut"
:
block
.
var
(
'mean'
),
# share memory
"VarianceOut"
:
block
.
var
(
'variance'
),
# share memory
"VarianceOut"
:
block
.
var
(
'variance'
),
# share memory
"SavedMean"
:
block
.
var
(
'saved_mean'
),
"SavedMean"
:
block
.
var
(
'saved_mean'
),
"SavedVariance"
:
block
.
var
(
'saved_variance'
)
"SavedVariance"
:
block
.
var
(
'saved_variance'
)
},
}
has_reserve_space
=
False
if
data_format
==
'NHWC'
:
flag
=
os
.
environ
.
get
(
'FLAGS_cudnn_batchnorm_spatial_persistent'
)
if
flag
is
not
None
and
flag
.
lower
()
in
[
'true'
,
'1'
]:
has_reserve_space
=
True
if
has_reserve_space
:
block
.
create_var
(
name
=
"reserve_space"
,
dtype
=
'float16'
)
outputs
[
"ReserveSpace"
]
=
block
.
var
(
'reserve_space'
)
del
os
.
environ
[
'FLAGS_cudnn_batchnorm_spatial_persistent'
]
bn_op
=
block
.
append_op
(
type
=
"batch_norm"
,
inputs
=
inputs
,
outputs
=
outputs
,
attrs
=
attrs
)
attrs
=
attrs
)
block
.
create_var
(
name
=
'y@GRAD'
,
dtype
=
'float32'
,
shape
=
y
.
shape
)
block
.
create_var
(
name
=
'y@GRAD'
,
dtype
=
'float32'
,
shape
=
y
.
shape
)
...
@@ -479,6 +492,17 @@ class TestBatchNormOpTrainingCase1(TestBatchNormOpTraining):
...
@@ -479,6 +492,17 @@ class TestBatchNormOpTrainingCase1(TestBatchNormOpTraining):
self
.
fetch_list
=
[
'y'
,
'mean'
,
'variance'
,
'x@GRAD'
]
self
.
fetch_list
=
[
'y'
,
'mean'
,
'variance'
,
'x@GRAD'
]
class
TestBatchNormOpTrainingCase2
(
TestBatchNormOpTraining
):
def
init_test_case
(
self
):
self
.
use_global_stats
=
False
self
.
no_grad_set
=
set
()
self
.
fetch_list
=
[
'y'
,
'mean'
,
'variance'
,
'saved_mean'
,
'saved_variance'
,
'x@GRAD'
,
'scale@GRAD'
,
'bias@GRAD'
]
os
.
environ
[
'FLAGS_cudnn_batchnorm_spatial_persistent'
]
=
"1"
class
TestBatchNormOpTrainingMomentumVariable
(
TestBatchNormOpTraining
):
class
TestBatchNormOpTrainingMomentumVariable
(
TestBatchNormOpTraining
):
def
init_test_case
(
self
):
def
init_test_case
(
self
):
self
.
use_momentum_variable
=
True
self
.
use_momentum_variable
=
True
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录