Commit e870947c authored by Kexin Zhao

fix batch norm fp16 param type

Parent 3233b2b3
......@@ -80,6 +80,29 @@ class BatchNormOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("X")->type());
// For float or float16 input tensor, the type of the scale, bias, mean,
// and var tensors should all be float.
auto bn_param_type = framework::proto::VarType::FP32;
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
"Scale input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
"Bias input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type,
framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
"Mean input should be of float type");
PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
ctx.Input<Tensor>("Variance")->type()),
"Variance input should be of float type");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
......
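The hunk above makes GetExpectedKernelType pick the compute kernel from the input tensor's type (FP32 or FP16) while insisting that Scale, Bias, Mean, and Variance stay FP32. Below is a minimal standalone sketch of that rule with no Paddle dependencies; the enum, the helper name ExpectedKernelType, and the messages are illustrative only, not Paddle APIs.

```cpp
// Standalone illustration (not Paddle code) of the check added above:
// the kernel follows the input's precision, the BN parameters must be FP32.
#include <cassert>
#include <iostream>

enum class DataType { FP16, FP32 };

// Hypothetical helper mirroring GetExpectedKernelType's enforcement.
DataType ExpectedKernelType(DataType x, DataType scale, DataType bias,
                            DataType mean, DataType variance) {
  const DataType bn_param_type = DataType::FP32;
  assert(scale == bn_param_type && "Scale input should be of float type");
  assert(bias == bn_param_type && "Bias input should be of float type");
  assert(mean == bn_param_type && "Mean input should be of float type");
  assert(variance == bn_param_type && "Variance input should be of float type");
  return x;  // the compute kernel runs in the input tensor's type
}

int main() {
  // An FP16 input with FP32 parameters is accepted; the kernel runs in FP16.
  DataType kernel = ExpectedKernelType(DataType::FP16, DataType::FP32,
                                       DataType::FP32, DataType::FP32,
                                       DataType::FP32);
  std::cout << (kernel == DataType::FP16 ? "FP16 kernel\n" : "FP32 kernel\n");
  return 0;
}
```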
......@@ -26,6 +26,8 @@ using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using bn_param_type = typename CudnnDataType<T>::bn_param_type;
void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
int *N, int *C, int *H, int *W, int *D) {
......@@ -104,8 +106,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
// Note: PERSISTENT not implemented for inference
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
bn_param_desc_, data_desc_, mode_));
bn_param_desc_, data_desc_, is_test ? CUDNN_BATCHNORM_SPATIAL : mode_));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
......@@ -118,15 +121,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
// alloc memory
y->mutable_data<T>(ctx.GetPlace());
mean_out->mutable_data<T>(ctx.GetPlace());
variance_out->mutable_data<T>(ctx.GetPlace());
saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace());
mean_out->mutable_data<bn_param_type<T>>(ctx.GetPlace());
variance_out->mutable_data<bn_param_type<T>>(ctx.GetPlace());
saved_mean->mutable_data<bn_param_type<T>>(ctx.GetPlace());
saved_variance->mutable_data<bn_param_type<T>>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
math::SetConstant<platform::CUDADeviceContext, T> functor;
functor(dev_ctx, saved_mean, static_cast<T>(0));
functor(dev_ctx, saved_variance, static_cast<T>(0));
math::SetConstant<platform::CUDADeviceContext, bn_param_type<T>> functor;
functor(dev_ctx, saved_mean, static_cast<bn_param_type<T>>(0));
functor(dev_ctx, saved_variance, static_cast<bn_param_type<T>>(0));
auto handle = dev_ctx.cudnn_handle();
......@@ -147,8 +150,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
bn_param_desc_, scale->template data<T>(), bias->template data<T>(),
est_mean->template data<T>(), est_var->template data<T>(), epsilon));
bn_param_desc_, scale->template data<bn_param_type<T>>(),
bias->template data<bn_param_type<T>>(),
est_mean->template data<bn_param_type<T>>(),
est_var->template data<bn_param_type<T>>(), epsilon));
} else {
// Run training mode.
// obtain running mean and running inv var, and see if we need to
......@@ -159,11 +164,14 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
data_desc_, x->template data<T>(), data_desc_,
y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
scale->template data<T>(), bias->template data<T>(), this_factor,
mean_out->template mutable_data<T>(ctx.GetPlace()),
variance_out->template mutable_data<T>(ctx.GetPlace()), epsilon,
saved_mean->template mutable_data<T>(ctx.GetPlace()),
saved_variance->template mutable_data<T>(ctx.GetPlace())));
scale->template data<bn_param_type<T>>(),
bias->template data<bn_param_type<T>>(), this_factor,
mean_out->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
variance_out->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
epsilon,
saved_mean->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
saved_variance->template mutable_data<bn_param_type<T>>(
ctx.GetPlace())));
}
// clean when exit.
......
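The bn_param_type<T> alias used throughout the CUDA kernel above maps the compute type T to the element type of the parameter tensors, so a float16 kernel still allocates and updates its statistics as float. Here is a self-contained sketch of that alias pattern, assuming a trait class analogous to (but simpler than) CudnnDataType; the names are local stand-ins.

```cpp
// Sketch of the type-mapping pattern behind bn_param_type<T> (names here are
// local stand-ins, not the Paddle/cuDNN declarations).
#include <cstdint>
#include <type_traits>

struct float16 { std::uint16_t x; };  // stand-in for platform::float16

template <typename T>
struct BatchNormParam {               // analogous to CudnnDataType<T>
  using type = T;                     // float -> float, double -> double
};
template <>
struct BatchNormParam<float16> {
  using type = float;                 // half inputs still use FP32 parameters
};

template <typename T>
using bn_param_type = typename BatchNormParam<T>::type;

static_assert(std::is_same<bn_param_type<float16>, float>::value,
              "FP16 compute keeps FP32 batch-norm parameters");
static_assert(std::is_same<bn_param_type<float>, float>::value, "");
static_assert(std::is_same<bn_param_type<double>, double>::value, "");

int main() { return 0; }
```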
......@@ -85,6 +85,9 @@ template <>
class CudnnDataType<float16> {
public:
static const cudnnDataType_t type = CUDNN_DATA_HALF;
// cudnn batch norm requires the Scale, Bias, Mean, and Variance
// tensors to be FLOAT when the input x is a HALF tensor
using bn_param_type = float;
// The scaling param type is float for HALF and FLOAT tensors
typedef const float ScalingParamType;
static ScalingParamType* kOne() {
......@@ -101,6 +104,7 @@ template <>
class CudnnDataType<float> {
public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
using bn_param_type = float;
typedef const float ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
......@@ -116,6 +120,7 @@ template <>
class CudnnDataType<double> {
public:
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
using bn_param_type = double;
typedef const double ScalingParamType;
static ScalingParamType* kOne() {
static ScalingParamType v = 1.0;
......
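One plausible motivation for cuDNN's FLOAT requirement on the batch-norm parameter tensors (and for allocating mean_out and variance_out as float above) is that half-precision accumulators silently drop small updates. The host-only sketch below emulates FP16 rounding to show the effect; the rounding helper and the constants are illustrative, not actual cuDNN behavior.

```cpp
// A small host-only illustration of why accumulating statistics in HALF is
// risky. FP16 is emulated by rounding to ~11 significant bits; no CUDA needed.
#include <cmath>
#include <cstdio>

float round_like_fp16(float v) {
  if (v == 0.0f) return v;
  int exp;
  float m = std::frexp(v, &exp);          // v = m * 2^exp, 0.5 <= |m| < 1
  m = std::round(m * 2048.0f) / 2048.0f;  // keep ~11 bits of mantissa
  return std::ldexp(m, exp);
}

int main() {
  // Accumulate 10000 small contributions of 1e-4 on top of 1.0.
  float sum32 = 1.0f, sum16 = 1.0f;
  for (int i = 0; i < 10000; ++i) {
    sum32 += 1e-4f;
    sum16 = round_like_fp16(sum16 + 1e-4f);  // each add rounds back to 1.0
  }
  std::printf("fp32 accumulator: %.4f\n", sum32);  // ~2.0
  std::printf("fp16 accumulator: %.4f\n", sum16);  // stuck at 1.0
  return 0;
}
```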
......@@ -193,7 +193,7 @@ class TestBatchNormOpInference(OpTest):
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def check_with_place(place, data_layout, dtype, shape):
def check_with_place(self, place, data_layout, dtype, shape):
epsilon = 0.00001
if len(shape) == 2:
x_shape = shape
......@@ -209,11 +209,11 @@ class TestBatchNormOpInference(OpTest):
scale_shape = [c]
x_val = np.random.random_sample(x_shape).astype(dtype)
scale_val = np.random.random_sample(scale_shape).astype(dtype)
bias_val = np.random.random_sample(scale_shape).astype(dtype)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(dtype)
variance = np.ones(scale_shape).astype(dtype)
mean = np.zeros(scale_shape).astype(np.float32)
variance = np.ones(scale_shape).astype(np.float32)
y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
epsilon, data_layout).astype(dtype)
......@@ -266,9 +266,13 @@ class TestBatchNormOpInference(OpTest):
batch_norm_op.run(scope, place)
# check inference result
self.__assert_close(y_tensor, y_out,
"inference output are different at " + str(place) +
", " + data_layout + ", " + str(np.dtype(dtype)))
self.__assert_close(
y_tensor,
y_out,
"inference output are different at " + str(place) + ", " +
data_layout + ", " + str(np.dtype(dtype)) +
str(np.array(y_tensor)) + str(y_out),
atol=2e-2)
def test_check_output(self):
places = [core.CPUPlace()]
......@@ -277,8 +281,9 @@ class TestBatchNormOpInference(OpTest):
for place in places:
for data_format in ["NCHW", "NHWC"]:
check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])
check_with_place(place, data_format, self.dtype, [2, 3])
self.check_with_place(place, data_format, self.dtype,
[2, 3, 4, 5])
self.check_with_place(place, data_format, self.dtype, [2, 3])
class TestFP16BatchNormOpInference(TestBatchNormOpInference):
......@@ -294,9 +299,9 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference):
for place in places:
for data_format in ["NCHW", "NHWC"]:
check_output_with_place(place, data_format, self.dtype,
self.check_with_place(place, data_format, self.dtype,
[2, 3, 4, 5])
check_output_with_place(place, data_format, self.dtype, [2, 3])
self.check_with_place(place, data_format, self.dtype, [2, 3])
class TestBatchNormOpTraining(OpTest):
......