Commit e870947c authored by Kexin Zhao

fix batch norm fp16 param type

Parent 3233b2b3
@@ -80,6 +80,29 @@ class BatchNormOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("SavedVariance", {C});
     ctx->ShareLoD("X", "Y");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const ExecutionContext &ctx) const override {
+    auto input_data_type =
+        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    // For float or float16 input tensor, the type of the scale, bias, mean,
+    // and var tensors should both be float.
+    auto bn_param_type = framework::proto::VarType::FP32;
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
+                      "Scale input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
+                      "Bias input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
+                      "Mean input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
+                                         ctx.Input<Tensor>("Variance")->type()),
+                      "Variance input should be of float type");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };

 class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
......
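In short, the new GetExpectedKernelType derives the kernel type from X but pins the parameter inputs to float32: whether X is float16 or float32, Scale/Bias/Mean/Variance must be float32. A standalone sketch of that rule, with a hypothetical enum and helper (not Paddle's VarType proto or PADDLE_ENFORCE macros):

```cpp
// Hypothetical enum and check, mirroring the rule enforced above; not Paddle API.
#include <cassert>

enum class DataType { FP16, FP32 };

// Whatever the dtype of X, every batch-norm parameter input must be FP32.
bool BnParamTypesAreValid(DataType scale, DataType bias, DataType mean,
                          DataType variance) {
  const DataType bn_param_type = DataType::FP32;
  return scale == bn_param_type && bias == bn_param_type &&
         mean == bn_param_type && variance == bn_param_type;
}

int main() {
  // fp16 or fp32 data: parameters are fp32 in both cases.
  assert(BnParamTypesAreValid(DataType::FP32, DataType::FP32, DataType::FP32,
                              DataType::FP32));
  // A half-precision Scale would be rejected.
  assert(!BnParamTypesAreValid(DataType::FP16, DataType::FP32, DataType::FP32,
                               DataType::FP32));
  return 0;
}
```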
@@ -26,6 +26,8 @@ using Tensor = framework::Tensor;
 using DataLayout = framework::DataLayout;
 template <typename T>
 using CudnnDataType = platform::CudnnDataType<T>;
+template <typename T>
+using bn_param_type = CudnnDataType<T>::bn_param_type;

 void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
                   int *N, int *C, int *H, int *W, int *D) {
@@ -104,8 +106,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
         x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
+    // Note: PERSISTENT not implemented for inference
     CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
-        bn_param_desc_, data_desc_, mode_));
+        bn_param_desc_, data_desc_, is_test ? CUDNN_BATCHNORM_SPATIAL : mode_));

     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *bias = ctx.Input<Tensor>("Bias");
@@ -118,15 +121,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     // alloc memory
     y->mutable_data<T>(ctx.GetPlace());
-    mean_out->mutable_data<T>(ctx.GetPlace());
-    variance_out->mutable_data<T>(ctx.GetPlace());
-    saved_mean->mutable_data<T>(ctx.GetPlace());
-    saved_variance->mutable_data<T>(ctx.GetPlace());
+    mean_out->mutable_data<bn_param_type<T>>(ctx.GetPlace());
+    variance_out->mutable_data<bn_param_type<T>>(ctx.GetPlace());
+    saved_mean->mutable_data<bn_param_type<T>>(ctx.GetPlace());
+    saved_variance->mutable_data<bn_param_type<T>>(ctx.GetPlace());

     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    math::SetConstant<platform::CUDADeviceContext, T> functor;
-    functor(dev_ctx, saved_mean, static_cast<T>(0));
-    functor(dev_ctx, saved_variance, static_cast<T>(0));
+    math::SetConstant<platform::CUDADeviceContext, bn_param_type<T>> functor;
+    functor(dev_ctx, saved_mean, static_cast<bn_param_type<T>>(0));
+    functor(dev_ctx, saved_variance, static_cast<bn_param_type<T>>(0));

     auto handle = dev_ctx.cudnn_handle();
@@ -147,8 +150,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
           CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
           data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
-          bn_param_desc_, scale->template data<T>(), bias->template data<T>(),
-          est_mean->template data<T>(), est_var->template data<T>(), epsilon));
+          bn_param_desc_, scale->template data<bn_param_type<T>>(),
+          bias->template data<bn_param_type<T>>(),
+          est_mean->template data<bn_param_type<T>>(),
+          est_var->template data<bn_param_type<T>>(), epsilon));
     } else {
       // Run training mode.
       // obtain running mean and running inv var, and see if we need to
@@ -159,11 +164,14 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
               handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
               data_desc_, x->template data<T>(), data_desc_,
               y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
-              scale->template data<T>(), bias->template data<T>(), this_factor,
-              mean_out->template mutable_data<T>(ctx.GetPlace()),
-              variance_out->template mutable_data<T>(ctx.GetPlace()), epsilon,
-              saved_mean->template mutable_data<T>(ctx.GetPlace()),
-              saved_variance->template mutable_data<T>(ctx.GetPlace())));
+              scale->template data<bn_param_type<T>>(),
+              bias->template data<bn_param_type<T>>(), this_factor,
+              mean_out->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
+              variance_out->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
+              epsilon,
+              saved_mean->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
+              saved_variance->template mutable_data<bn_param_type<T>>(
+                  ctx.GetPlace())));
     }

     // clean when exit.
......
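The CUDA kernel changes above keep the data tensors in T but allocate the running and saved statistics with the batch-norm parameter type, so a float16 input still carries float32 statistics. A minimal sketch of that allocation pattern, using a hypothetical trait and std::vector buffers in place of Paddle tensors:

```cpp
// Hypothetical trait and workspace mirroring the allocation pattern above;
// nothing here is Paddle API.
#include <cstdint>
#include <type_traits>
#include <vector>

struct half_t { uint16_t bits; };  // stand-in for platform::float16

template <typename T> struct BnParam { using type = float; };  // fp16/fp32 -> fp32
template <> struct BnParam<double> { using type = double; };   // fp64 stays fp64
template <typename T> using BnParamT = typename BnParam<T>::type;

template <typename T>
struct BatchNormWorkspace {
  std::vector<T> y;                       // output keeps the input element type
  std::vector<BnParamT<T>> mean_out;      // running statistics use the param type
  std::vector<BnParamT<T>> variance_out;
  std::vector<BnParamT<T>> saved_mean;
  std::vector<BnParamT<T>> saved_variance;
};

static_assert(std::is_same<BnParamT<half_t>, float>::value,
              "fp16 data keeps fp32 statistics");
```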
@@ -85,6 +85,9 @@ template <>
 class CudnnDataType<float16> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_HALF;
+  // cudnn batch norm requires that Scale, Bias, Mean, and Variance
+  // to be FLOAT tensors when the input x is HALF tensor
+  static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT;
   // The scaling param type is float for HALF and FLOAT tensors
   typedef const float ScalingParamType;
   static ScalingParamType* kOne() {
@@ -101,6 +104,7 @@ template <>
 class CudnnDataType<float> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
+  static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT;
   typedef const float ScalingParamType;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
@@ -116,6 +120,7 @@ template <>
 class CudnnDataType<double> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
+  static const cudnnDataType_t bn_param_type = CUDNN_DATA_DOUBLE;
   typedef const double ScalingParamType;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
......
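The bn_param_type members added above encode the dtype cuDNN expects for the batch-norm parameter descriptor: float when the data is half or float, double when the data is double. The same mapping written out as an illustrative helper (the function is a sketch, not part of the header):

```cpp
// Illustrative restatement of the bn_param_type values above; assumes cudnn.h
// is available for the cudnnDataType_t enum.
#include <cudnn.h>

constexpr cudnnDataType_t BnParamCudnnType(cudnnDataType_t data_type) {
  // HALF and FLOAT data use FLOAT parameter tensors; DOUBLE stays DOUBLE.
  return data_type == CUDNN_DATA_DOUBLE ? CUDNN_DATA_DOUBLE : CUDNN_DATA_FLOAT;
}

static_assert(BnParamCudnnType(CUDNN_DATA_HALF) == CUDNN_DATA_FLOAT,
              "fp16 data uses fp32 batch-norm params");
static_assert(BnParamCudnnType(CUDNN_DATA_FLOAT) == CUDNN_DATA_FLOAT,
              "fp32 data uses fp32 batch-norm params");
static_assert(BnParamCudnnType(CUDNN_DATA_DOUBLE) == CUDNN_DATA_DOUBLE,
              "fp64 data uses fp64 batch-norm params");
```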
...@@ -193,7 +193,7 @@ class TestBatchNormOpInference(OpTest): ...@@ -193,7 +193,7 @@ class TestBatchNormOpInference(OpTest):
def __assert_close(self, tensor, np_array, msg, atol=1e-4): def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def check_with_place(place, data_layout, dtype, shape): def check_with_place(self, place, data_layout, dtype, shape):
epsilon = 0.00001 epsilon = 0.00001
if len(shape) == 2: if len(shape) == 2:
x_shape = shape x_shape = shape
...@@ -209,11 +209,11 @@ class TestBatchNormOpInference(OpTest): ...@@ -209,11 +209,11 @@ class TestBatchNormOpInference(OpTest):
scale_shape = [c] scale_shape = [c]
x_val = np.random.random_sample(x_shape).astype(dtype) x_val = np.random.random_sample(x_shape).astype(dtype)
scale_val = np.random.random_sample(scale_shape).astype(dtype) scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(dtype) bias_val = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(dtype) mean = np.zeros(scale_shape).astype(np.float32)
variance = np.ones(scale_shape).astype(dtype) variance = np.ones(scale_shape).astype(np.float32)
y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance, y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
epsilon, data_layout).astype(dtype) epsilon, data_layout).astype(dtype)
...@@ -266,9 +266,13 @@ class TestBatchNormOpInference(OpTest): ...@@ -266,9 +266,13 @@ class TestBatchNormOpInference(OpTest):
batch_norm_op.run(scope, place) batch_norm_op.run(scope, place)
# check inference result # check inference result
self.__assert_close(y_tensor, y_out, self.__assert_close(
"inference output are different at " + str(place) + y_tensor,
", " + data_layout + ", " + str(np.dtype(dtype))) y_out,
"inference output are different at " + str(place) + ", " +
data_layout + ", " + str(np.dtype(dtype)) +
str(np.array(y_tensor)) + str(y_out),
atol=2e-2)
def test_check_output(self): def test_check_output(self):
places = [core.CPUPlace()] places = [core.CPUPlace()]
...@@ -277,8 +281,9 @@ class TestBatchNormOpInference(OpTest): ...@@ -277,8 +281,9 @@ class TestBatchNormOpInference(OpTest):
for place in places: for place in places:
for data_format in ["NCHW", "NHWC"]: for data_format in ["NCHW", "NHWC"]:
check_with_place(place, data_format, self.dtype, [2, 3, 4, 5]) self.check_with_place(place, data_format, self.dtype,
check_with_place(place, data_format, self.dtype, [2, 3]) [2, 3, 4, 5])
self.check_with_place(place, data_format, self.dtype, [2, 3])
class TestFP16BatchNormOpInference(TestBatchNormOpInference): class TestFP16BatchNormOpInference(TestBatchNormOpInference):
...@@ -294,9 +299,9 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference): ...@@ -294,9 +299,9 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference):
for place in places: for place in places:
for data_format in ["NCHW", "NHWC"]: for data_format in ["NCHW", "NHWC"]:
check_output_with_place(place, data_format, self.dtype, self.check_with_place(place, data_format, self.dtype,
[2, 3, 4, 5]) [2, 3, 4, 5])
check_output_with_place(place, data_format, self.dtype, [2, 3]) self.check_with_place(place, data_format, self.dtype, [2, 3])
class TestBatchNormOpTraining(OpTest): class TestBatchNormOpTraining(OpTest):
......