From 643a268e7991ad71e8146e78246397fafd83147f Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Tue, 21 Dec 2021 14:39:58 +0800 Subject: [PATCH] Support FP16 mean (#38289) * mean first version * fix scalar mean * add fp16 dtype for api --- .../kernel_primitives/functor_primitives.h | 15 ++- paddle/fluid/operators/mean_op.cu | 53 ++++----- .../operators/reduce_ops/reduce_functor_op.h | 2 +- .../operators/reduce_ops/reduce_mean_op.cu | 2 + .../operators/reduce_ops/reduce_mean_op.h | 13 +++ .../reduce_ops/reduce_mean_op.part.cu | 6 + .../fluid/operators/reduce_ops/reduce_op.cu.h | 57 ++++++--- paddle/fluid/operators/reduce_ops/reduce_op.h | 8 +- .../fluid/tests/unittests/test_mean_op.py | 109 +++++++++++++++++- python/paddle/tensor/stat.py | 3 +- 10 files changed, 205 insertions(+), 63 deletions(-) diff --git a/paddle/fluid/operators/kernel_primitives/functor_primitives.h b/paddle/fluid/operators/kernel_primitives/functor_primitives.h index d7aed8595ba..2bd8721b82f 100644 --- a/paddle/fluid/operators/kernel_primitives/functor_primitives.h +++ b/paddle/fluid/operators/kernel_primitives/functor_primitives.h @@ -14,7 +14,10 @@ #pragma once +#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/platform/eigen_ext.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -74,16 +77,20 @@ struct IdentityFunctor { */ template struct DivideFunctor { - HOSTDEVICE inline DivideFunctor() { n_inv = static_cast(1.0f); } + private: + using MPType = typename ::paddle::operators::details::MPTypeTrait::Type; + + public: + HOSTDEVICE inline DivideFunctor() { n_inv = static_cast(1.0f); } - HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((Tx)(1.0 / n)) {} + HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((MPType)(1.0 / n)) {} HOSTDEVICE inline Ty operator()(const Tx& x) const { - return static_cast(x * n_inv); + return static_cast(static_cast(x) * n_inv); } private: - Tx n_inv; + MPType n_inv; }; /** diff --git a/paddle/fluid/operators/mean_op.cu b/paddle/fluid/operators/mean_op.cu index 1a10b7033f6..c48fc79326f 100644 --- a/paddle/fluid/operators/mean_op.cu +++ b/paddle/fluid/operators/mean_op.cu @@ -18,30 +18,23 @@ limitations under the License. */ #include namespace cub = hipcub; #endif +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h" #include "paddle/fluid/operators/mean_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { -template -struct DivideFunctor { - HOSTDEVICE explicit inline DivideFunctor(int n) - : n_inv(static_cast(1.0 / n)) {} - - HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; } - - private: - T n_inv; -}; - template __global__ void MeanRunKernel(const T* in_data, T* out_data, int N) { + using MT = typename details::MPTypeTrait::Type; int idx = blockDim.x * blockIdx.x + threadIdx.x; - T data = in_data[0]; + auto data = static_cast(in_data[0]); for (; idx < N; idx += blockDim.x * gridDim.x) { - out_data[idx] = data / (static_cast(N)); + out_data[idx] = static_cast(data / (static_cast(N))); } } @@ -52,27 +45,29 @@ class MeanCUDAKernel : public framework::OpKernel { auto* input = context.Input("X"); auto* output = context.Output("Out"); - output->mutable_data(context.GetPlace()); - auto size_prob = input->numel(); const T* in_data = input->data(); T* out_data = output->mutable_data(context.GetPlace()); + auto numel = input->numel(); + auto rank = input->dims().size(); + auto place = context.GetPlace(); auto stream = context.cuda_device_context().stream(); - DivideFunctor transformer(size_prob); - cub::TransformInputIterator, const T*> trans_x( - in_data, transformer); - size_t temp_storage_bytes = 0; + if (rank == 0) { // scalar + auto gpu_place = BOOST_GET(platform::CUDAPlace, place); + memory::Copy(gpu_place, out_data, gpu_place, in_data, numel * sizeof(T), + stream); + return; + } - auto err = cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, trans_x, - out_data, size_prob, stream); - PADDLE_ENFORCE_GPU_SUCCESS(err); - framework::Tensor tmp; - auto* temp_storage = tmp.mutable_data( - framework::make_ddim({static_cast(temp_storage_bytes)}), - context.GetPlace()); - err = cub::DeviceReduce::Sum(temp_storage, temp_storage_bytes, trans_x, - out_data, size_prob, stream); - PADDLE_ENFORCE_GPU_SUCCESS(err); + using MT = typename details::MPTypeTrait::Type; + using Div = kernel_primitives::DivideFunctor; + std::vector reduce_dims; + reduce_dims.reserve(rank); + for (decltype(rank) i = 0; i < rank; ++i) { + reduce_dims.push_back(i); + } + TensorReduceFunctorImpl( + *input, output, Div(numel), reduce_dims, stream); } }; diff --git a/paddle/fluid/operators/reduce_ops/reduce_functor_op.h b/paddle/fluid/operators/reduce_ops/reduce_functor_op.h index dc79666b72f..72d21d7074e 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_functor_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_functor_op.h @@ -77,7 +77,7 @@ struct CustomSub { template struct CustomMean { - using Transformer = kps::DivideFunctor; + using Transformer = kps::DivideFunctor; inline Ty initial() { return static_cast(0.0f); } diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu index a50b09564fd..197ced2beaa 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.cu @@ -19,5 +19,7 @@ REGISTER_OP_CUDA_KERNEL( reduce_mean, ops::ReduceCudaKernel, + ops::ReduceCudaKernel, ops::ReduceCudaKernel, ops::ReduceCudaKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.h b/paddle/fluid/operators/reduce_ops/reduce_mean_op.h index 240c43bc6d0..2b2349b095c 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.h @@ -35,5 +35,18 @@ struct MeanGradFunctor { } }; +// TODO(zengjinle): Should refine the numeric stability of FP16 reduce_mean +// and reduce_mean_grad later. +struct FP16MeanGradFunctor { + template + void operator()(const DeviceContext& place, X* x, Y* y, DX* dx, DY* dy, + const Dim& dim, int size) { + dx->device(place) = (dy->template cast().broadcast(dim) / + dx->template cast().constant(size)) + .template cast(); + } +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu index 0e133d5447f..4cc2577f6b2 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu +++ b/paddle/fluid/operators/reduce_ops/reduce_mean_op.part.cu @@ -20,6 +20,12 @@ using CUDAReduceMeanGradKernel = ops::ReduceGradKernel; +using FP16CUDAReduceMeanGradKernel = + ops::ReduceGradKernel; + REGISTER_OP_CUDA_KERNEL(reduce_mean_grad, CUDAReduceMeanGradKernel, + FP16CUDAReduceMeanGradKernel, CUDAReduceMeanGradKernel, CUDAReduceMeanGradKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h index 77fa5768843..5a82176a9c9 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h @@ -38,7 +38,9 @@ namespace cub = hipcub; #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/fast_divmod.h" +#include "paddle/fluid/string/string_helper.h" // Reduce split or not, Whether to use ReduceHigherDim #define REDUCE_SPLIT_BOUNDARY 512 @@ -814,11 +816,42 @@ static void LaunchReduceKernel(const Tx* x_data, Ty* y_data, } } +template class ReduceOp, + typename TransformOp> +static typename std::enable_if::value, + void>::type +CubTensorReduceFunctorImpl(const Tx* x_data, Ty* y_data, + const TransformOp& transform, int reduce_num, + const platform::Place& place, gpuStream_t stream) { + auto reducer = ReduceOp(); + cub::TransformInputIterator trans_x(x_data, + transform); + size_t temp_storage_bytes = 0; + cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data, + reduce_num, reducer, reducer.initial(), stream); + framework::Tensor tmp; + auto* temp_storage = tmp.mutable_data( + framework::make_ddim({static_cast(temp_storage_bytes)}), place); + cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data, + reduce_num, reducer, reducer.initial(), stream); +} + +template class ReduceOp, + typename TransformOp> +static typename std::enable_if::value, + void>::type +CubTensorReduceFunctorImpl(const Tx* x_data, Ty* y_data, + const TransformOp& transform, int reduce_num, + const platform::Place& place, gpuStream_t stream) { + PADDLE_THROW(platform::errors::InvalidArgument( + "Tx should not be float16 when using cub::DeviceReduce::Reduce().")); +} + template class ReduceOp, typename TransformOp> void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y, const TransformOp& transform, - std::vector origin_reduce_dims, + const std::vector& origin_reduce_dims, gpuStream_t stream) { auto x_dim = framework::vectorize(x.dims()); auto config = ReduceConfig(origin_reduce_dims, x_dim); @@ -848,25 +881,11 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y, } config.SetOutputData(y_data, x.place(), &tmp); - bool use_cub_reduce = (config.reduce_num == numel) && - (!std::is_same::value); + constexpr bool kIsTxFP16 = std::is_same::value; + bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16; if (use_cub_reduce) { - // launch CUB::Reduce - auto reducer = ReduceOp(); - cub::TransformInputIterator trans_x(x_data, - transform); - size_t temp_storage_bytes = 0; - cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data, - config.reduce_num, reducer, reducer.initial(), - stream); - framework::Tensor tmp; - auto* temp_storage = tmp.mutable_data( - framework::make_ddim({static_cast(temp_storage_bytes)}), - x.place()); - cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data, - config.reduce_num, reducer, reducer.initial(), - stream); - + CubTensorReduceFunctorImpl( + x_data, y_data, transform, config.reduce_num, x.place(), stream); return; } diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 440fc1f7e37..d3b938272e6 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -703,7 +703,7 @@ class ReduceCudaKernel : public framework::OpKernel { std::vector reduce_dims = GetReduceDim(dims, input->dims().size(), reduce_all); int reduce_num = 1; - for (int i = 0; i < input->dims().size(); i++) { + for (auto i : reduce_dims) { reduce_num *= (input->dims())[i]; } gpuStream_t stream = context.cuda_device_context().stream(); @@ -713,8 +713,10 @@ class ReduceCudaKernel : public framework::OpKernel { TensorReduceFunc( *input, output, reduce_dims, reduce_num, stream)); } else { - TensorReduceFunctorImpl>( - *input, output, TransformOp(reduce_num), reduce_dims, stream); + using MPType = typename details::MPTypeTrait::Type; + TensorReduceFunctorImpl>( + *input, output, TransformOp(reduce_num), reduce_dims, + stream); } } }; diff --git a/python/paddle/fluid/tests/unittests/test_mean_op.py b/python/paddle/fluid/tests/unittests/test_mean_op.py index 064de404aff..7a49770e579 100644 --- a/python/paddle/fluid/tests/unittests/test_mean_op.py +++ b/python/paddle/fluid/tests/unittests/test_mean_op.py @@ -63,17 +63,25 @@ class TestMeanOpError(unittest.TestCase): class TestFP16MeanOp(TestMeanOp): def init_dtype_type(self): self.dtype = np.float16 + self.__class__.no_need_check_grad = True def test_check_output(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_output_with_place(place, atol=2e-3) + self.check_output_with_place(place) def test_checkout_grad(self): place = core.CUDAPlace(0) if core.is_float16_supported(place): - self.check_grad_with_place( - place, ['X'], 'Out', max_relative_error=0.8) + with fluid.dygraph.guard(): + x_np = np.random.random((10, 10)).astype(self.dtype) + x = paddle.to_tensor(x_np) + x.stop_gradient = False + y = fluid.layers.mean(x) + dx = paddle.grad(y, x)[0].numpy() + dx_expected = self.dtype(1.0 / np.prod(x_np.shape)) * np.ones( + x_np.shape).astype(self.dtype) + self.assertTrue(np.array_equal(dx, dx_expected)) @OpTestTool.skip_if_not_cpu_bf16() @@ -98,6 +106,14 @@ def ref_reduce_mean(x, axis=None, keepdim=False, reduce_all=False): return np.mean(x, axis=axis, keepdims=keepdim) +def ref_reduce_mean_grad(x, axis, dtype): + if reduce_all: + axis = list(range(x.ndim)) + + shape = [x.shape[i] for i in axis] + return (1.0 / np.prod(shape) * np.ones(shape)).astype(dtype) + + class TestReduceMeanOp(OpTest): def setUp(self): self.op_type = 'reduce_mean' @@ -105,11 +121,13 @@ class TestReduceMeanOp(OpTest): self.shape = [2, 3, 4, 5] self.axis = [0] self.keepdim = False - self.reduce_all = False self.set_attrs() np.random.seed(10) x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype) + if not hasattr(self, "reduce_all"): + self.reduce_all = (not self.axis) or len(self.axis) == len(x_np) + out_np = ref_reduce_mean(x_np, self.axis, self.keepdim, self.reduce_all) self.inputs = {'X': x_np} self.outputs = {'Out': out_np} @@ -119,14 +137,39 @@ class TestReduceMeanOp(OpTest): 'reduce_all': self.reduce_all } + if self.dtype == 'float16': + self.__class__.no_need_check_grad = True + def set_attrs(self): pass def test_check_output(self): - self.check_output() + if self.dtype != 'float16': + self.check_output() + else: + if not core.is_compiled_with_cuda(): + return + place = paddle.CUDAPlace(0) + self.check_output_with_place(place=place) def test_check_grad(self): - self.check_grad(['X'], ['Out']) + if self.dtype != 'float16': + self.check_grad(['X'], ['Out']) + else: + return + if not core.is_compiled_with_cuda(): + return + place = paddle.CUDAPlace(0) + if core.is_float16_supported(place): + return + with fluid.dygraph.guard(place=place): + x = paddle.tensor(self.inputs['X']) + y = paddle.mean( + x, axis=self.attrs['dim'], keepdim=self.attrs['keep_dim']) + dx = paddle.grad(y, x)[0].numpy() + dx_expected = ref_reduce_mean_grad( + self.inputs['X'], self.attrs['dim'], self.dtype) + self.assertTrue(np.array_equal(dx, dx_expected)) class TestReduceMeanOpDefaultAttrs(TestReduceMeanOp): @@ -146,47 +189,101 @@ class TestReduceMeanOpFloat32(TestReduceMeanOp): self.dtype = 'float32' +class TestReduceMeanOpFloat16(TestReduceMeanOp): + def set_attrs(self): + self.dtype = 'float16' + + class TestReduceMeanOpShape1D(TestReduceMeanOp): def set_attrs(self): self.shape = [100] +class TestReduceMeanOpShape1DFP16(TestReduceMeanOp): + def set_attrs(self): + self.shape = [100] + self.dtype = 'float16' + + class TestReduceMeanOpShape6D(TestReduceMeanOp): def set_attrs(self): self.shape = [2, 3, 4, 5, 6, 7] +class TestReduceMeanOpShape6DFP16(TestReduceMeanOp): + def set_attrs(self): + self.shape = [2, 3, 4, 5, 6, 7] + self.dtype = 'float16' + + class TestReduceMeanOpAxisAll(TestReduceMeanOp): def set_attrs(self): self.axis = [0, 1, 2, 3] +class TestReduceMeanOpAxisAllFP16(TestReduceMeanOp): + def set_attrs(self): + self.axis = [0, 1, 2, 3] + self.dtype = 'float16' + + class TestReduceMeanOpAxisTuple(TestReduceMeanOp): def set_attrs(self): self.axis = (0, 1, 2) +class TestReduceMeanOpAxisTupleFP16(TestReduceMeanOp): + def set_attrs(self): + self.axis = (0, 1, 2) + self.dtype = 'float16' + + class TestReduceMeanOpAxisNegative(TestReduceMeanOp): def set_attrs(self): self.axis = [-2, -1] +class TestReduceMeanOpAxisNegativeFP16(TestReduceMeanOp): + def set_attrs(self): + self.axis = [-2, -1] + self.dtype = 'float16' + + class TestReduceMeanOpKeepdimTrue1(TestReduceMeanOp): def set_attrs(self): self.keepdim = True +class TestReduceMeanOpKeepdimTrue1FP16(TestReduceMeanOp): + def set_attrs(self): + self.keepdim = True + self.dtype = 'float16' + + class TestReduceMeanOpKeepdimTrue2(TestReduceMeanOp): def set_attrs(self): self.axis = [0, 1, 2, 3] self.keepdim = True +class TestReduceMeanOpKeepdimTrue2FP16(TestReduceMeanOp): + def set_attrs(self): + self.axis = [0, 1, 2, 3] + self.keepdim = True + self.dtype = 'float16' + + class TestReduceMeanOpReduceAllTrue(TestReduceMeanOp): def set_attrs(self): self.reduce_all = True +class TestReduceMeanOpReduceAllTrueFP16(TestReduceMeanOp): + def set_attrs(self): + self.reduce_all = True + self.dtype = 'float16' + + class TestMeanAPI(unittest.TestCase): # test paddle.tensor.stat.mean diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index ca01d77c808..6a016e42b5a 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -92,7 +92,8 @@ def mean(x, axis=None, keepdim=False, name=None): return _C_ops.reduce_mean(x, 'dim', axis, 'keep_dim', keepdim, 'reduce_all', reduce_all) - check_variable_and_dtype(x, 'x/input', ['uint16', 'float32', 'float64'], + check_variable_and_dtype(x, 'x/input', + ['uint16', 'float16', 'float32', 'float64'], 'mean/reduce_mean') check_type(axis, 'axis/dim', (int, list, tuple), 'mean/reduce_mean') if isinstance(axis, (list, tuple)): -- GitLab