Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
906e7f92
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
906e7f92
编写于
9月 23, 2020
作者:
Z
Zhang Ting
提交者:
GitHub
9月 23, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add fuse_bn_act op (#27230)
* add fused_bn_add_relu op
上级
5034d181
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
1120 addition
and
6 deletion
+1120
-6
cmake/operators.cmake
cmake/operators.cmake
+2
-1
paddle/fluid/operators/fused/CMakeLists.txt
paddle/fluid/operators/fused/CMakeLists.txt
+7
-1
paddle/fluid/operators/fused/fused_bn_add_activation_op.cc
paddle/fluid/operators/fused/fused_bn_add_activation_op.cc
+255
-0
paddle/fluid/operators/fused/fused_bn_add_activation_op.cu
paddle/fluid/operators/fused/fused_bn_add_activation_op.cu
+338
-0
paddle/fluid/operators/fused/fused_bn_add_activation_op.h
paddle/fluid/operators/fused/fused_bn_add_activation_op.h
+106
-0
python/paddle/fluid/contrib/layers/nn.py
python/paddle/fluid/contrib/layers/nn.py
+190
-1
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
+1
-0
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
+6
-3
python/paddle/fluid/tests/unittests/test_fused_bn_add_act.py
python/paddle/fluid/tests/unittests/test_fused_bn_add_act.py
+215
-0
未找到文件。
cmake/operators.cmake
浏览文件 @
906e7f92
...
...
@@ -127,7 +127,8 @@ function(op_library TARGET)
"tensor_array_read_write_op"
"tensorrt_engine_op"
"conv_fusion_op"
"fusion_transpose_flatten_concat_op"
"fusion_conv_inception_op"
"sync_batch_norm_op"
"dgc_op"
"fused_fc_elementwise_layernorm_op"
"multihead_matmul_op"
"fusion_group_op"
"fused_bn_activation_op"
"fused_embedding_eltwise_layernorm_op"
"fusion_gru_op"
)
"multihead_matmul_op"
"fusion_group_op"
"fused_bn_activation_op"
"fused_embedding_eltwise_layernorm_op"
"fusion_gru_op"
"fused_bn_add_activation_op"
)
if
(
"
${
TARGET
}
"
STREQUAL
"
${
manual_pybind_op
}
"
)
set
(
pybind_flag 1
)
endif
()
...
...
paddle/fluid/operators/fused/CMakeLists.txt
浏览文件 @
906e7f92
...
...
@@ -8,7 +8,8 @@ register_operators(EXCLUDES
multihead_matmul_op
fused_embedding_eltwise_layernorm_op
fusion_group_op
fusion_gru_op
)
fusion_gru_op
fused_bn_add_activation_op
)
# fusion_gru_op does not have CUDA kernel
op_library
(
fusion_gru_op
)
...
...
@@ -47,4 +48,9 @@ if (WITH_GPU)
file
(
APPEND
${
pybind_file
}
"USE_CUDA_ONLY_OP(fusion_group);
\n
"
)
cc_test
(
test_fusion_group_op SRCS fusion_group_op_test.cc DEPS fusion_group_op
)
endif
()
# fused_bn_add_activation
if
(
NOT
${
CUDNN_VERSION
}
VERSION_LESS 7401
)
op_library
(
fused_bn_add_activation_op
)
file
(
APPEND
${
pybind_file
}
"USE_CUDA_ONLY_OP(fused_bn_add_activation);
\n
"
)
endif
()
endif
()
paddle/fluid/operators/fused/fused_bn_add_activation_op.cc
0 → 100644
浏览文件 @
906e7f92
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused/fused_bn_add_activation_op.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
LoDTensor
=
framework
::
LoDTensor
;
void
FusedBatchNormAddActOp
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FusedBatchNormAddActOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Z"
),
"Input"
,
"Z"
,
"FusedBatchNormAddActOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Scale"
),
"Input"
,
"Scale"
,
"FusedBatchNormAddActOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Bias"
),
"Input"
,
"Bias"
,
"FusedBatchNormAddActOp"
);
// check output
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Y"
),
"Output"
,
"Y"
,
"FusedBatchNormAddActOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"MeanOut"
),
"Output"
,
"MeanOut"
,
"FusedBatchNormAddActOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"VarianceOut"
),
"Output"
,
"VarianceOut"
,
"FusedBatchNormAddActOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SavedMean"
),
"Output"
,
"SavedMean"
,
"FusedBatchNormAddActOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SavedVariance"
),
"Output"
,
"SavedVariance"
,
"FusedBatchNormAddActOp"
);
const
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
const
auto
z_dims
=
ctx
->
GetInputDim
(
"Z"
);
PADDLE_ENFORCE_EQ
(
x_dims
,
z_dims
,
platform
::
errors
::
InvalidArgument
(
"ShapeError: the shapes of input "
"must be equal. But received: the shape "
"of input X = [%s], and the shape of "
"input Y = [%s]"
,
x_dims
,
z_dims
));
PADDLE_ENFORCE_GE
(
x_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"ShapeError: the dimensions of input "
"must greater than or equal to 2."
"But received: the shape of input "
"= [%s], the dimension of input = "
"[%d]"
,
x_dims
,
x_dims
.
size
()));
PADDLE_ENFORCE_LE
(
x_dims
.
size
(),
5
,
platform
::
errors
::
InvalidArgument
(
"ShapeError: the dimensions of input "
"must smaller than or equal to 5."
"But received: the shape of input "
"= [%s], the dimension of input = "
"[%d]"
,
x_dims
,
x_dims
.
size
()));
const
int64_t
C
=
x_dims
[
x_dims
.
size
()
-
1
];
auto
scale_dim
=
ctx
->
GetInputDim
(
"Scale"
);
auto
bias_dim
=
ctx
->
GetInputDim
(
"Bias"
);
PADDLE_ENFORCE_EQ
(
scale_dim
.
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"ShapeError: the dimension of scale must equal to 1."
"But received: the shape of scale is [%s], the dimension "
"of scale is [%d]"
,
scale_dim
,
scale_dim
.
size
()));
PADDLE_ENFORCE_EQ
(
bias_dim
.
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"ShapeError: the dimension of bias must equal to 1."
"But received: the shape of bias is [%s],the dimension "
"of bias is [%d]"
,
bias_dim
,
bias_dim
.
size
()));
bool
check
=
true
;
if
((
!
ctx
->
IsRuntime
())
&&
(
framework
::
product
(
scale_dim
)
<=
0
||
framework
::
product
(
bias_dim
)
<=
0
))
{
check
=
false
;
}
if
(
check
)
{
PADDLE_ENFORCE_EQ
(
scale_dim
[
0
],
C
,
platform
::
errors
::
InvalidArgument
(
"ShapeError: the shape of scale must equal to [%d]"
"But received: the shape of scale is [%d]"
,
C
,
scale_dim
[
0
]));
PADDLE_ENFORCE_EQ
(
bias_dim
[
0
],
C
,
platform
::
errors
::
InvalidArgument
(
"ShapeError: the shape of bias must equal to [%d]"
"But received: the shape of bias is [%d]"
,
C
,
bias_dim
[
0
]));
}
ctx
->
SetOutputDim
(
"Y"
,
x_dims
);
ctx
->
SetOutputDim
(
"MeanOut"
,
{
C
});
ctx
->
SetOutputDim
(
"VarianceOut"
,
{
C
});
ctx
->
SetOutputDim
(
"SavedMean"
,
{
C
});
ctx
->
SetOutputDim
(
"SavedVariance"
,
{
C
});
ctx
->
ShareLoD
(
"X"
,
"Y"
);
}
framework
::
OpKernelType
FusedBatchNormAddActOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
auto
input_data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
);
// By default, the type of the scale, bias, mean,
// and var tensors should be float when input tensor's dtype is float16.
auto
bn_param_type
=
framework
::
proto
::
VarType
::
FP32
;
PADDLE_ENFORCE_EQ
(
bn_param_type
,
ctx
.
Input
<
Tensor
>
(
"Scale"
)
->
type
(),
platform
::
errors
::
InvalidArgument
(
"Scale input should be of float type"
));
PADDLE_ENFORCE_EQ
(
bn_param_type
,
ctx
.
Input
<
Tensor
>
(
"Bias"
)
->
type
(),
platform
::
errors
::
InvalidArgument
(
"Bias input should be of float type"
));
framework
::
LibraryType
library
=
framework
::
LibraryType
::
kPlain
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
(),
layout
,
library
);
}
void
FusedBatchNormAddActOpMaker
::
Make
()
{
AddInput
(
"X"
,
"The input tensor"
);
AddInput
(
"Z"
,
"The input tensor"
);
AddInput
(
"Scale"
,
"Scale is a 1-dimensional tensor of size C "
"that is applied to the output"
);
AddInput
(
"Bias"
,
"Bias is a 1-dimensional tensor of size C "
"that is applied to the output"
);
AddOutput
(
"Y"
,
"result after normalization"
);
AddOutput
(
"MeanOut"
,
"Share memory with Mean. "
"Store the global mean when training"
);
AddOutput
(
"VarianceOut"
,
"Share memory with Variance. "
"Store the global Variance when training"
);
AddOutput
(
"SavedMean"
,
"Mean of the current mini batch, "
"will apply to output when training"
)
.
AsIntermediate
();
AddOutput
(
"SavedVariance"
,
"Variance of the current mini batch, "
"will apply to output when training"
)
.
AsIntermediate
();
AddOutput
(
"ReserveSpace"
,
"Reserve GPU space for triggering the new semi-persistent "
"NHWC kernel"
);
AddAttr
<
float
>
(
"momentum"
,
""
).
SetDefault
(
0.9
);
AddAttr
<
float
>
(
"epsilon"
,
""
)
.
SetDefault
(
1e-5
)
.
AddCustomChecker
([](
const
float
&
epsilon
)
{
PADDLE_ENFORCE_EQ
(
epsilon
>=
0.0
f
&&
epsilon
<=
0.001
f
,
true
,
platform
::
errors
::
InvalidArgument
(
"'epsilon' should be between 0.0 and 0.001."
));
});
AddAttr
<
std
::
string
>
(
"act_type"
,
"The activation type to be fused."
)
.
SetDefault
(
"relu"
);
AddComment
(
R"DOC(
Fused Batch Normalization with activation.
Batch Norm has been implemented as discussed in the paper:
https://arxiv.org/pdf/1502.03167.pdf
Batch Norm can be used as a normalizer function for conv2d and fully_connected operations.
Now, the required data format for FusedBatchNormAddActOp is NHWC `[batch, in_height, in_width, in_channels]`.
)DOC"
);
}
void
FusedBatchNormAddActGradOp
::
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
{
// check input
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"FusedBatchNormAddActGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Z"
),
"Input"
,
"Z"
,
"FusedBatchNormAddActGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Scale"
),
"Input"
,
"Scale"
,
"FusedBatchNormAddActGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"SavedMean"
),
"Input"
,
"SavedMean"
,
"FusedBatchNormAddActGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"SavedVariance"
),
"Input"
,
"SavedVariance"
,
"FusedBatchNormAddActGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Y"
)),
"Input"
,
framework
::
GradVarName
(
"Y"
),
"FusedBatchNormAddActGradOp"
);
// check output
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Output"
,
framework
::
GradVarName
(
"X"
),
"FusedBatchNormAddActGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Z"
)),
"Output"
,
framework
::
GradVarName
(
"Z"
),
"FusedBatchNormAddActGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Scale"
)),
"Output"
,
framework
::
GradVarName
(
"Scale"
),
"FusedBatchNormAddActGradOp"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Bias"
)),
"Output"
,
framework
::
GradVarName
(
"Bias"
),
"FusedBatchNormAddActGradOp"
);
const
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
const
int
C
=
in_dims
[
in_dims
.
size
()
-
1
];
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
in_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Z"
),
in_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Scale"
),
{
C
});
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"Bias"
),
{
C
});
}
framework
::
OpKernelType
FusedBatchNormAddActGradOp
::
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
{
const
auto
*
var
=
ctx
.
InputVar
(
framework
::
GradVarName
(
"Y"
));
if
(
var
==
nullptr
)
{
PADDLE_THROW
(
platform
::
errors
::
NotFound
(
"Can not find Y@GRAD in the execution context."
));
}
const
Tensor
*
t
=
nullptr
;
if
(
var
->
IsType
<
Tensor
>
())
{
t
=
&
var
->
Get
<
Tensor
>
();
}
else
if
(
var
->
IsType
<
LoDTensor
>
())
{
t
=
&
var
->
Get
<
LoDTensor
>
();
}
if
(
t
==
nullptr
)
{
PADDLE_THROW
(
platform
::
errors
::
NotFound
(
"Can not get the tensor value of Y@GRAD."
));
}
framework
::
LibraryType
library
=
framework
::
LibraryType
::
kPlain
;
framework
::
DataLayout
layout
=
framework
::
DataLayout
::
kAnyLayout
;
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
(),
layout
,
library
);
}
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
fused_bn_add_activation
,
ops
::
FusedBatchNormAddActOp
,
ops
::
FusedBatchNormAddActOpMaker
,
ops
::
FusedBatchNormAddActOpInferVarType
,
ops
::
FusedBatchNormAddActGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
FusedBatchNormAddActGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
fused_bn_add_activation_grad
,
ops
::
FusedBatchNormAddActGradOp
);
paddle/fluid/operators/fused/fused_bn_add_activation_op.cu
0 → 100644
浏览文件 @
906e7f92
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <cfloat>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/fused/fused_bn_add_activation_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/norm_utils.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
DECLARE_bool
(
cudnn_batchnorm_spatial_persistent
);
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
using
CudnnDataType
=
platform
::
CudnnDataType
<
T
>
;
template
<
typename
T
>
using
BatchNormParamType
=
typename
CudnnDataType
<
T
>::
BatchNormParamType
;
template
<
typename
T
>
class
FusedBatchNormAddActKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"It must use CUDAPlace."
));
double
epsilon
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
float
momentum
=
ctx
.
Attr
<
float
>
(
"momentum"
);
std
::
string
act_type
=
ctx
.
Attr
<
std
::
string
>
(
"act_type"
);
if
(
epsilon
<=
CUDNN_BN_MIN_EPSILON
-
FLT_EPSILON
)
{
LOG
(
ERROR
)
<<
"Provided epsilon is smaller than "
<<
"CUDNN_BN_MIN_EPSILON. Setting it to "
<<
"CUDNN_BN_MIN_EPSILON instead."
;
}
epsilon
=
std
::
max
(
epsilon
,
CUDNN_BN_MIN_EPSILON
);
// Get the size for each dimension.
// NHWC [batch_size, in_height, in_width, in_channels]
const
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
auto
*
z
=
ctx
.
Input
<
Tensor
>
(
"Z"
);
const
auto
&
in_dims
=
x
->
dims
();
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
auto
*
mean_out
=
ctx
.
Output
<
Tensor
>
(
"MeanOut"
);
auto
*
variance_out
=
ctx
.
Output
<
Tensor
>
(
"VarianceOut"
);
mean_out
->
mutable_data
<
BatchNormParamType
<
T
>>
(
ctx
.
GetPlace
());
variance_out
->
mutable_data
<
BatchNormParamType
<
T
>>
(
ctx
.
GetPlace
());
auto
*
saved_mean
=
ctx
.
Output
<
Tensor
>
(
"SavedMean"
);
auto
*
saved_variance
=
ctx
.
Output
<
Tensor
>
(
"SavedVariance"
);
saved_mean
->
mutable_data
<
BatchNormParamType
<
T
>>
(
ctx
.
GetPlace
());
saved_variance
->
mutable_data
<
BatchNormParamType
<
T
>>
(
ctx
.
GetPlace
());
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
y
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
N
,
C
,
H
,
W
,
D
;
const
DataLayout
data_layout
=
DataLayout
::
kNHWC
;
ExtractNCWHD
(
in_dims
,
data_layout
,
&
N
,
&
C
,
&
H
,
&
W
,
&
D
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
// ------------------- cudnn descriptors ---------------------
auto
handle
=
dev_ctx
.
cudnn_handle
();
cudnnTensorDescriptor_t
data_desc_
;
cudnnTensorDescriptor_t
bn_param_desc_
;
cudnnBatchNormMode_t
mode_
=
CUDNN_BATCHNORM_SPATIAL_PERSISTENT
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
bn_param_desc_
));
std
::
vector
<
int
>
dims
=
{
N
,
C
,
H
,
W
,
D
};
std
::
vector
<
int
>
strides
=
{
H
*
W
*
D
*
C
,
1
,
W
*
D
*
C
,
D
*
C
,
C
};
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
data_desc_
,
CudnnDataType
<
T
>::
type
,
in_dims
.
size
()
>
3
?
in_dims
.
size
()
:
4
,
dims
.
data
(),
strides
.
data
()));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDeriveBNTensorDescriptor
(
bn_param_desc_
,
data_desc_
,
mode_
));
double
this_factor
=
1.
-
momentum
;
cudnnBatchNormOps_t
bnOps_
=
CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION
;
platform
::
ScopedActivationDescriptor
scope_act_desc
;
cudnnActivationDescriptor_t
activation_desc_
=
scope_act_desc
.
descriptor
<
T
>
(
act_type
);
size_t
workspace_size
=
0
;
size_t
reserve_space_size
=
0
;
void
*
reserve_space_ptr
=
nullptr
;
void
*
workspace_ptr
=
nullptr
;
Tensor
workspace_tensor
;
// Create reserve space and workspace for batch norm.
// Create tensor for each batchnorm op, it will be used in the
// backward. Thus this tensor shouldn't be temp.
auto
*
reserve_space
=
ctx
.
Output
<
Tensor
>
(
"ReserveSpace"
);
PADDLE_ENFORCE_NOT_NULL
(
reserve_space
,
platform
::
errors
::
NotFound
(
"The argument ReserveSpace of batch_norm op is not found."
));
// --------------- cudnn batchnorm workspace ---------------
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize
(
/*handle=*/
handle
,
/*mode=*/
mode_
,
/*bnOps=*/
bnOps_
,
/*xDesc=*/
data_desc_
,
/*zDesc=*/
data_desc_
,
/*yDesc=*/
data_desc_
,
/*bnScaleBiasMeanVarDesc=*/
bn_param_desc_
,
/*activationDesc=*/
activation_desc_
,
/*sizeInBytes=*/
&
workspace_size
));
// -------------- cudnn batchnorm reserve space --------------
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnGetBatchNormalizationTrainingExReserveSpaceSize
(
/*handle=*/
handle
,
/*mode=*/
mode_
,
/*bnOps=*/
bnOps_
,
/*activationDesc=*/
activation_desc_
,
/*xDesc=*/
data_desc_
,
/*sizeInBytes=*/
&
reserve_space_size
));
reserve_space_ptr
=
reserve_space
->
mutable_data
(
ctx
.
GetPlace
(),
x
->
type
(),
reserve_space_size
);
workspace_ptr
=
workspace_tensor
.
mutable_data
(
ctx
.
GetPlace
(),
x
->
type
(),
workspace_size
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnBatchNormalizationForwardTrainingEx
(
handle
,
mode_
,
bnOps_
,
CudnnDataType
<
T
>::
kOne
(),
CudnnDataType
<
T
>::
kZero
(),
data_desc_
,
x
->
template
data
<
T
>(),
data_desc_
,
z
->
template
data
<
T
>(),
data_desc_
,
y
->
template
data
<
T
>(),
bn_param_desc_
,
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
bias
->
template
data
<
BatchNormParamType
<
T
>
>
(),
this_factor
,
mean_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
variance_out
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
epsilon
,
saved_mean
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
saved_variance
->
template
mutable_data
<
BatchNormParamType
<
T
>
>
(
ctx
.
GetPlace
()),
activation_desc_
,
workspace_ptr
,
workspace_size
,
reserve_space_ptr
,
reserve_space_size
));
// clean when exit.
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
bn_param_desc_
));
}
};
template
<
typename
T
>
class
FusedBatchNormAddActGradKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"It must use CUDAPlace."
));
double
epsilon
=
static_cast
<
double
>
(
ctx
.
Attr
<
float
>
(
"epsilon"
));
std
::
string
act_type
=
ctx
.
Attr
<
std
::
string
>
(
"act_type"
);
const
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
const
auto
*
z
=
ctx
.
Input
<
Tensor
>
(
"Z"
);
const
auto
*
y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
const
auto
*
d_y
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
const
auto
*
scale
=
ctx
.
Input
<
Tensor
>
(
"Scale"
);
const
auto
*
bias
=
ctx
.
Input
<
Tensor
>
(
"Bias"
);
const
auto
*
reserve_space
=
ctx
.
Input
<
Tensor
>
(
"ReserveSpace"
);
const
auto
&
in_dims
=
x
->
dims
();
int
N
,
C
,
H
,
W
,
D
;
const
DataLayout
data_layout
=
DataLayout
::
kNHWC
;
ExtractNCWHD
(
in_dims
,
data_layout
,
&
N
,
&
C
,
&
H
,
&
W
,
&
D
);
// init output
auto
*
d_x
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
d_z
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Z"
));
auto
*
d_scale
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Scale"
));
auto
*
d_bias
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Bias"
));
d_x
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
d_z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
PADDLE_ENFORCE_EQ
(
d_scale
&&
d_bias
,
true
,
platform
::
errors
::
PreconditionNotMet
(
"Both the scale grad and the bias grad must not be null."
));
d_scale
->
mutable_data
<
BatchNormParamType
<
T
>>
(
ctx
.
GetPlace
());
d_bias
->
mutable_data
<
BatchNormParamType
<
T
>>
(
ctx
.
GetPlace
());
PADDLE_ENFORCE_EQ
(
scale
->
dims
().
size
(),
1UL
,
platform
::
errors
::
PreconditionNotMet
(
"The scale only has one dimension."
));
PADDLE_ENFORCE_EQ
(
scale
->
dims
()[
0
],
C
,
platform
::
errors
::
PreconditionNotMet
(
"The size of scale is equal to the channel of Input(X)."
));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
std
::
vector
<
int
>
dims
=
{
N
,
C
,
H
,
W
,
D
};
std
::
vector
<
int
>
strides
=
{
H
*
W
*
C
*
D
,
1
,
W
*
D
*
C
,
D
*
C
,
C
};
// ------------------- cudnn descriptors ---------------------
cudnnTensorDescriptor_t
data_desc_
;
cudnnTensorDescriptor_t
bn_param_desc_
;
cudnnBatchNormMode_t
mode_
=
CUDNN_BATCHNORM_SPATIAL_PERSISTENT
;
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnCreateTensorDescriptor
(
&
bn_param_desc_
));
if
(
epsilon
<=
CUDNN_BN_MIN_EPSILON
-
FLT_EPSILON
)
{
LOG
(
ERROR
)
<<
"Provided epsilon is smaller than "
<<
"CUDNN_BN_MIN_EPSILON. Setting it to "
<<
"CUDNN_BN_MIN_EPSILON instead."
;
}
epsilon
=
std
::
max
(
epsilon
,
CUDNN_BN_MIN_EPSILON
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnSetTensorNdDescriptor
(
data_desc_
,
CudnnDataType
<
T
>::
type
,
in_dims
.
size
()
>
3
?
in_dims
.
size
()
:
4
,
dims
.
data
(),
strides
.
data
()));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDeriveBNTensorDescriptor
(
bn_param_desc_
,
data_desc_
,
mode_
));
const
auto
*
saved_mean
=
ctx
.
Input
<
Tensor
>
(
"SavedMean"
);
const
auto
*
saved_var
=
ctx
.
Input
<
Tensor
>
(
"SavedVariance"
);
const
auto
*
saved_mean_data
=
saved_mean
->
template
data
<
BatchNormParamType
<
T
>
>
();
const
auto
*
saved_var_data
=
saved_var
->
template
data
<
BatchNormParamType
<
T
>
>
();
size_t
workspace_size
=
0
;
void
*
workspace_ptr
=
nullptr
;
Tensor
workspace_tensor
;
auto
reserve_space_size
=
reserve_space
->
memory_size
();
cudnnBatchNormOps_t
bnOps_
=
CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION
;
platform
::
ScopedActivationDescriptor
scope_act_desc
;
cudnnActivationDescriptor_t
activation_desc_
=
scope_act_desc
.
descriptor
<
T
>
(
act_type
);
// --------------- cudnn batchnorm workspace ---------------
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnGetBatchNormalizationBackwardExWorkspaceSize
(
/*handle=*/
dev_ctx
.
cudnn_handle
(),
/*mode=*/
mode_
,
/*bnOps=*/
bnOps_
,
/*xDesc=*/
data_desc_
,
/*yDesc=*/
data_desc_
,
/*dyDesc=*/
data_desc_
,
/*dzDesc=*/
data_desc_
,
/*dxDesc=*/
data_desc_
,
/*bnScaleBiasMeanVarDesc=*/
bn_param_desc_
,
/*activationDesc=*/
activation_desc_
,
/*sizeInBytes=*/
&
workspace_size
));
workspace_ptr
=
workspace_tensor
.
mutable_data
(
ctx
.
GetPlace
(),
x
->
type
(),
workspace_size
);
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnBatchNormalizationBackwardEx
(
/*handle=*/
dev_ctx
.
cudnn_handle
(),
/*mode=*/
mode_
,
/*bnOps=*/
bnOps_
,
/*alphaDataDiff=*/
CudnnDataType
<
T
>::
kOne
(),
/*betaDataDiff=*/
CudnnDataType
<
T
>::
kZero
(),
/*alphaParamDiff=*/
CudnnDataType
<
T
>::
kOne
(),
/*betaParamDiff=*/
CudnnDataType
<
T
>::
kZero
(),
/*xDesc=*/
data_desc_
,
/*xData=*/
x
->
template
data
<
T
>(),
/*yDesc=*/
data_desc_
,
/*yData=*/
y
->
template
data
<
T
>(),
/*dyDesc=*/
data_desc_
,
/*dyData=*/
d_y
->
template
data
<
T
>(),
/*dzDesc=*/
data_desc_
,
/*dzData=*/
d_z
->
template
data
<
T
>(),
/*dxDesc=*/
data_desc_
,
/*dxData=*/
d_x
->
template
data
<
T
>(),
/*dBnScaleBiasDesc=*/
bn_param_desc_
,
/*bnScaleData=*/
scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
/*bnBiasData=*/
bias
->
template
data
<
BatchNormParamType
<
T
>
>
(),
/*dBnScaleData=*/
d_scale
->
template
data
<
BatchNormParamType
<
T
>
>
(),
/*dBnBiasData=*/
d_bias
->
template
data
<
BatchNormParamType
<
T
>
>
(),
/*epsilon=*/
epsilon
,
/*savedMean=*/
saved_mean_data
,
/*savedInvVariance=*/
saved_var_data
,
/*activationDesmc=*/
activation_desc_
,
/*workspace=*/
workspace_ptr
,
/*workSpaceSizeInBytes=*/
workspace_size
,
/*reserveSpace=*/
const_cast
<
T
*>
(
reserve_space
->
template
data
<
T
>()),
/*reserveSpaceSizeInBytes=*/
reserve_space_size
));
// clean when exit.
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
data_desc_
));
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cudnnDestroyTensorDescriptor
(
bn_param_desc_
));
}
};
}
// namespace operators
}
// namespace paddle
#if CUDNN_VERSION >= 7401
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
fused_bn_add_activation
,
ops
::
FusedBatchNormAddActKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
);
REGISTER_OP_CUDA_KERNEL
(
fused_bn_add_activation_grad
,
ops
::
FusedBatchNormAddActGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
);
#endif
paddle/fluid/operators/fused/fused_bn_add_activation_op.h
0 → 100644
浏览文件 @
906e7f92
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
class
FusedBatchNormAddActOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
};
class
FusedBatchNormAddActGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
;
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
};
class
FusedBatchNormAddActOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
;
};
template
<
typename
T
>
class
FusedBatchNormAddActGradOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
op
)
const
override
{
op
->
SetType
(
this
->
ForwardOpType
()
+
"_grad"
);
op
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
op
->
SetInput
(
"Z"
,
this
->
Input
(
"Z"
));
op
->
SetInput
(
"Y"
,
this
->
Output
(
"Y"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Y"
),
this
->
OutputGrad
(
"Y"
));
op
->
SetInput
(
"Scale"
,
this
->
Input
(
"Scale"
));
op
->
SetInput
(
"Bias"
,
this
->
Input
(
"Bias"
));
op
->
SetInput
(
"SavedMean"
,
this
->
Output
(
"SavedMean"
));
op
->
SetInput
(
"SavedVariance"
,
this
->
Output
(
"SavedVariance"
));
op
->
SetInput
(
"ReserveSpace"
,
this
->
Output
(
"ReserveSpace"
));
op
->
SetAttrMap
(
this
->
Attrs
());
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Z"
),
this
->
InputGrad
(
"Z"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Scale"
),
this
->
InputGrad
(
"Scale"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Bias"
),
this
->
InputGrad
(
"Bias"
));
}
};
class
FusedBatchNormAddActOpInferVarType
:
public
framework
::
PassInDtypeAndVarTypeToOutput
{
protected:
std
::
unordered_map
<
std
::
string
,
std
::
string
>&
GetInputOutputWithSameType
()
const
override
{
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
m
{{
"X"
,
/*->*/
"Y"
}};
return
m
;
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
FusedBatchNormAddActKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
};
template
<
typename
DeviceContext
,
typename
T
>
class
FusedBatchNormAddActGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
;
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/contrib/layers/nn.py
浏览文件 @
906e7f92
...
...
@@ -45,6 +45,7 @@ from paddle.fluid.initializer import Normal, Constant, NumpyArrayInitializer
from
paddle.fluid.data_feeder
import
check_variable_and_dtype
,
check_type
,
check_dtype
,
convert_dtype
from
paddle.fluid
import
core
from
paddle.fluid.param_attr
import
ParamAttr
from
paddle.fluid.entry_attr
import
ProbabilityEntry
,
CountFilterEntry
from
paddle.fluid.framework
import
Variable
,
convert_np_dtype_to_dtype_
...
...
@@ -57,7 +58,7 @@ __all__ = [
'multiclass_nms2'
,
'search_pyramid_hash'
,
'shuffle_batch'
,
'partial_concat'
,
'sparse_embedding'
,
'partial_sum'
,
'tdm_child'
,
'rank_attention'
,
'tdm_sampler'
,
'batch_fc'
,
'_pull_box_extended_sparse'
,
'bilateral_slice'
,
'correlation'
'correlation'
,
'fused_bn_add_act'
]
...
...
@@ -1625,3 +1626,191 @@ def correlation(x,
},
outputs
=
{
"Output"
:
output
})
return
output
def
fused_bn_add_act
(
x
,
y
,
momentum
=
0.9
,
epsilon
=
1e-05
,
param_attr
=
None
,
bias_attr
=
None
,
moving_mean_name
=
None
,
moving_variance_name
=
None
,
act
=
None
,
name
=
None
):
"""
This Op performs batch norm on input x, and adds the result to input y. Then
it performs activation on the sum. The data format of inputs must be NHWC
`[batch, in_height, in_width, in_channels]`.
Args:
x(Tensor): The rank of input tensor can be 2, 3, 4, 5. The data type
is float16.
y(Tensor): The rank of input tensor can be 2, 3, 4, 5. The data type
is float16.
momentum(float|Tensor, optional): The value used for the moving_mean and
moving_var computation. This should be a float number or a tensor with
shape [1] and data type as float32. The updated formula is:
:math:`moving\_mean = moving\_mean * momentum + new\_mean * (1. - momentum)`
:math:`moving\_var = moving\_var * momentum + new\_var * (1. - momentum)`
Default is 0.9.
epsilon(float, optional): A value added to the denominator for
numerical stability. Default is 1e-5.
param_attr(ParamAttr, optional): The parameter attribute for Parameter `scale`
of batch_norm. If it is set to None or one attribute of ParamAttr, batch_norm
will create ParamAttr as param_attr, the name of scale can be set in ParamAttr.
If the Initializer of the param_attr is not set, the parameter is initialized
with Xavier. Default: None.
bias_attr(ParamAttr, optional): The parameter attribute for the bias of batch_norm.
If it is set to None or one attribute of ParamAttr, batch_norm
will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr.
If the Initializer of the bias_attr is not set, the bias is initialized zero.
Default: None.
moving_mean_name(str, optional): The name of moving_mean which store the global Mean. If it
is set to None, batch_norm will save global mean with a random name, otherwise, batch_norm
will save global mean with the string.
moving_variance_name(str, optional): The name of the moving_variance which store the global Variance.
If it is set to None, batch_norm will save global variance with a random name, otherwise, batch_norm
will save global variance with the string.
act(string, optional): Activation type, linear|relu|prelu|...
name(str, optional): For detailed information, please refer to :ref:`api_guide_Name`.
Usually name is no need to set and None by default.
Examples:
.. code-block:: python
import paddle.fluid as fluid
def build_program(main_program, startup_program):
with fluid.program_guard(main_program, startup_program):
x = fluid.layers.data(name='x', shape=[1, 28, 28], dtype='float32')
y = fluid.layers.data(name="y", shape=[1], dtype='int64')
conv1_1 = fluid.layers.conv2d(
input=x,
filter_size=3,
num_filters=32,
stride=1,
padding=1,
act=None,
bias_attr=False,
data_format='NHWC')
conv1_2 = fluid.layers.conv2d(
input=x,
filter_size=3,
num_filters=32,
stride=1,
padding=1,
act=None,
bias_attr=False,
data_format='NHWC')
bn = fluid.layers.batch_norm(
input=conv1_1,
act=None,
data_layout='NHWC')
fused_bn_add_act = fluid.contrib.layers.fused_bn_add_act(conv1_2, bn)
prediction = fluid.layers.fc(input=fused_bn_add_act, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=y)
loss = fluid.layers.mean(loss)
sgd = fluid.optimizer.SGD(learning_rate=0.001)
sgd = fluid.contrib.mixed_precision.decorate(
sgd, use_dynamic_loss_scaling=True, init_loss_scaling=128.0)
sgd.minimize(loss)
return x, y, loss
iters = 5
batch_size = 16
support_gpu = fluid.is_compiled_with_cuda()
if support_gpu:
main_program = fluid.Program()
startup_program = fluid.Program()
place = fluid.CUDAPlace(0)
x, y, loss = build_program(main_program, startup_program)
feeder = fluid.DataFeeder(feed_list=[x, y], place=place)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=batch_size)
exe = fluid.Executor(place)
scope = fluid.Scope()
with fluid.scope_guard(scope):
exe.run(startup_program)
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(main_program, feed=feeder.feed(data), fetch_list=[loss])
"""
helper
=
LayerHelper
(
'fused_bn_add_act'
,
**
locals
())
check_variable_and_dtype
(
x
,
'input'
,
[
'float16'
,
'float32'
,
'float64'
],
'fused_bn_add_act'
)
check_variable_and_dtype
(
y
,
'input'
,
[
'float16'
,
'float32'
,
'float64'
],
'fused_bn_add_act'
)
bn_param_dtype
=
core
.
VarDesc
.
VarType
.
FP32
x_shape
=
x
.
shape
channel_num
=
x_shape
[
-
1
]
param_shape
=
[
channel_num
]
# create parameter
scale
=
helper
.
create_parameter
(
attr
=
helper
.
param_attr
,
shape
=
param_shape
,
dtype
=
bn_param_dtype
,
default_initializer
=
Constant
(
1.0
))
bias
=
helper
.
create_parameter
(
attr
=
helper
.
bias_attr
,
shape
=
param_shape
,
dtype
=
bn_param_dtype
,
is_bias
=
True
)
mean
=
helper
.
create_parameter
(
attr
=
ParamAttr
(
name
=
moving_mean_name
,
initializer
=
Constant
(
0.0
),
trainable
=
False
),
shape
=
param_shape
,
dtype
=
bn_param_dtype
)
mean
.
stop_gradient
=
True
variance
=
helper
.
create_parameter
(
attr
=
ParamAttr
(
name
=
moving_variance_name
,
initializer
=
Constant
(
1.0
),
trainable
=
False
),
shape
=
param_shape
,
dtype
=
bn_param_dtype
)
variance
.
stop_gradient
=
True
# create output
# mean and mean_out share the same memory
mean_out
=
mean
# variance and variance out share the same memory
variance_out
=
variance
saved_mean
=
helper
.
create_variable_for_type_inference
(
dtype
=
bn_param_dtype
,
stop_gradient
=
True
)
saved_variance
=
helper
.
create_variable_for_type_inference
(
dtype
=
bn_param_dtype
,
stop_gradient
=
True
)
reserve_space
=
helper
.
create_variable_for_type_inference
(
dtype
=
core
.
VarDesc
.
VarType
.
FP16
,
stop_gradient
=
True
)
batch_norm_out
=
helper
.
create_variable_for_type_inference
(
core
.
VarDesc
.
VarType
.
FP16
)
inputs
=
{
"X"
:
x
,
"Z"
:
y
,
"Scale"
:
scale
,
"Bias"
:
bias
,
}
attrs
=
{
"epsilon"
:
epsilon
,
'momentum'
:
momentum
}
outputs
=
{
"Y"
:
batch_norm_out
,
"MeanOut"
:
mean_out
,
"VarianceOut"
:
variance_out
,
"SavedMean"
:
saved_mean
,
"SavedVariance"
:
saved_variance
,
"ReserveSpace"
:
reserve_space
}
helper
.
append_op
(
type
=
"fused_bn_add_activation"
,
inputs
=
inputs
,
outputs
=
outputs
,
attrs
=
attrs
)
return
batch_norm_out
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
浏览文件 @
906e7f92
...
...
@@ -135,6 +135,7 @@ gray_list = {
'get_tensor_from_selected_rows'
,
'sign'
,
'cast'
,
'fused_bn_add_activation'
,
}
'''
# The set of ops that don't support fp16 calculation
...
...
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
浏览文件 @
906e7f92
...
...
@@ -69,8 +69,10 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
]
for
in_name
in
op
.
input_names
:
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
op
.
type
==
'batch_norm'
:
if
in_name
!=
'X'
:
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
op
.
type
in
[
'batch_norm'
,
'fused_bn_add_activation'
]:
if
in_name
not
in
{
'X'
,
'Z'
}:
continue
for
in_var_name
in
op
.
input
(
in_name
):
in_var
=
block
.
var
(
in_var_name
)
...
...
@@ -102,7 +104,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
op
.
_set_attr
(
'in_dtype'
,
dest_dtype
)
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
dest_dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
for
out_name
in
op
.
output_names
:
if
op
.
type
==
'batch_norm'
and
out_name
!=
'Y'
:
if
op
.
type
in
[
'batch_norm'
,
'fused_bn_add_activation'
]
and
out_name
!=
'Y'
:
continue
for
out_var_name
in
op
.
output
(
out_name
):
out_var
=
block
.
var
(
out_var_name
)
...
...
python/paddle/fluid/tests/unittests/test_fused_bn_add_act.py
0 → 100644
浏览文件 @
906e7f92
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid
import
core
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"Paddle core is not compiled with CUDA"
)
class
TestFusedBnAddActAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
conv_param_attr1
=
fluid
.
ParamAttr
(
name
=
'conv2d_1.weight'
,
initializer
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
),
learning_rate
=
0.001
)
self
.
conv_param_attr2
=
fluid
.
ParamAttr
(
name
=
'conv2d_2.weight'
,
initializer
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
),
learning_rate
=
0.001
)
self
.
bn_param_attr1
=
fluid
.
ParamAttr
(
name
=
'batch_norm_w_1'
,
initializer
=
fluid
.
initializer
.
Constant
(
value
=
1.0
))
self
.
bn_bias_attr1
=
fluid
.
ParamAttr
(
name
=
'batch_norm_b_1'
,
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.0
))
self
.
bn_param_attr2
=
fluid
.
ParamAttr
(
name
=
'batch_norm_w_2'
,
initializer
=
fluid
.
initializer
.
Constant
(
value
=
1.0
))
self
.
bn_bias_attr2
=
fluid
.
ParamAttr
(
name
=
'batch_norm_b_2'
,
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.0
))
self
.
fc_param_attr
=
fluid
.
ParamAttr
(
name
=
'fc.weight'
,
initializer
=
fluid
.
initializer
.
Xavier
(
uniform
=
False
))
def
build_fused_program
(
self
,
main_program
,
startup_program
,
use_cuda
,
seed
=
1
):
with
fluid
.
program_guard
(
main_program
,
startup_program
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
y
=
fluid
.
layers
.
data
(
name
=
"y"
,
shape
=
[
1
],
dtype
=
'int64'
)
conv1_1
=
fluid
.
layers
.
conv2d
(
input
=
x
,
filter_size
=
3
,
num_filters
=
32
,
stride
=
1
,
padding
=
1
,
act
=
None
,
param_attr
=
self
.
conv_param_attr1
,
bias_attr
=
False
,
data_format
=
'NHWC'
)
conv1_2
=
fluid
.
layers
.
conv2d
(
input
=
x
,
filter_size
=
3
,
num_filters
=
32
,
stride
=
1
,
padding
=
1
,
act
=
None
,
param_attr
=
self
.
conv_param_attr2
,
bias_attr
=
False
,
data_format
=
'NHWC'
)
bn
=
fluid
.
layers
.
batch_norm
(
input
=
conv1_1
,
param_attr
=
self
.
bn_param_attr1
,
bias_attr
=
self
.
bn_bias_attr1
,
act
=
None
,
data_layout
=
'NHWC'
)
fused_bn_add_act
=
fluid
.
contrib
.
layers
.
fused_bn_add_act
(
conv1_2
,
bn
,
param_attr
=
self
.
bn_param_attr2
,
bias_attr
=
self
.
bn_bias_attr2
)
prediction
=
fluid
.
layers
.
fc
(
input
=
fused_bn_add_act
,
size
=
10
,
act
=
'softmax'
,
param_attr
=
self
.
fc_param_attr
)
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
y
)
loss
=
fluid
.
layers
.
mean
(
loss
)
sgd
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
sgd
=
fluid
.
contrib
.
mixed_precision
.
decorate
(
sgd
,
use_dynamic_loss_scaling
=
True
,
init_loss_scaling
=
128.0
)
sgd
.
minimize
(
loss
)
return
x
,
y
,
loss
def
build_origin_program
(
self
,
main_program
,
startup_program
,
use_cuda
,
seed
=
1
):
with
fluid
.
program_guard
(
main_program
,
startup_program
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
1
,
28
,
28
],
dtype
=
'float32'
)
y
=
fluid
.
layers
.
data
(
name
=
"y"
,
shape
=
[
1
],
dtype
=
'int64'
)
conv1_1
=
fluid
.
layers
.
conv2d
(
input
=
x
,
filter_size
=
3
,
num_filters
=
32
,
stride
=
1
,
padding
=
1
,
act
=
None
,
param_attr
=
self
.
conv_param_attr1
,
bias_attr
=
False
,
data_format
=
'NHWC'
)
conv1_2
=
fluid
.
layers
.
conv2d
(
input
=
x
,
filter_size
=
3
,
num_filters
=
32
,
stride
=
1
,
padding
=
1
,
act
=
None
,
param_attr
=
self
.
conv_param_attr2
,
bias_attr
=
False
,
data_format
=
'NHWC'
)
bn1
=
fluid
.
layers
.
batch_norm
(
input
=
conv1_1
,
param_attr
=
self
.
bn_param_attr1
,
bias_attr
=
self
.
bn_bias_attr1
,
act
=
None
,
data_layout
=
'NHWC'
)
bn2
=
fluid
.
layers
.
batch_norm
(
input
=
conv1_2
,
param_attr
=
self
.
bn_param_attr2
,
bias_attr
=
self
.
bn_bias_attr2
,
act
=
None
,
data_layout
=
'NHWC'
)
out
=
bn1
+
bn2
out
=
fluid
.
layers
.
relu
(
out
)
prediction
=
fluid
.
layers
.
fc
(
input
=
out
,
size
=
10
,
act
=
'softmax'
,
param_attr
=
self
.
fc_param_attr
)
loss
=
fluid
.
layers
.
cross_entropy
(
input
=
prediction
,
label
=
y
)
loss
=
fluid
.
layers
.
mean
(
loss
)
sgd
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
sgd
=
fluid
.
contrib
.
mixed_precision
.
decorate
(
sgd
,
use_dynamic_loss_scaling
=
True
,
init_loss_scaling
=
128.0
)
sgd
.
minimize
(
loss
)
return
x
,
y
,
loss
def
check
(
self
,
place
,
use_cuda
):
paddle
.
manual_seed
(
1
)
paddle
.
framework
.
random
.
_manual_program_seed
(
1
)
iters
=
5
batch_size
=
16
# build_fused_program
main_program
=
fluid
.
Program
()
startup_program
=
fluid
.
Program
()
x
,
y
,
loss
=
self
.
build_fused_program
(
main_program
,
startup_program
,
use_cuda
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
x
,
y
],
place
=
place
)
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
batch_size
)
exe
=
fluid
.
Executor
(
place
)
loss_vals_fused
=
[]
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
exe
.
run
(
startup_program
)
for
_
in
range
(
iters
):
data
=
next
(
train_reader
())
loss_v
=
exe
.
run
(
main_program
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
loss
])
loss_vals_fused
.
append
(
loss_v
[
0
][
0
])
# build_origin_program
main_program
=
fluid
.
Program
()
startup_program
=
fluid
.
Program
()
x
,
y
,
loss
=
self
.
build_origin_program
(
main_program
,
startup_program
,
use_cuda
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
x
,
y
],
place
=
place
)
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
batch_size
)
loss_vals
=
[]
scope
=
fluid
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
exe
.
run
(
startup_program
)
for
_
in
range
(
iters
):
data
=
next
(
train_reader
())
loss_v
=
exe
.
run
(
main_program
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
loss
])
loss_vals
.
append
(
loss_v
[
0
][
0
])
# check loss
for
i
in
range
(
iters
):
self
.
assertAlmostEqual
(
loss_vals
[
i
],
loss_vals_fused
[
i
],
delta
=
1e-5
)
def
test_fuse_bn_add_act
(
self
):
place
=
fluid
.
CUDAPlace
(
0
)
self
.
check
(
place
,
use_cuda
=
True
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录