From 606939de76af62afc1d4170b6b2e53e4ba743a74 Mon Sep 17 00:00:00 2001
From: jiangcheng
Date: Tue, 15 Jun 2021 10:56:55 +0800
Subject: [PATCH] Support reduce_sum_op float16 (#32966)

* add reduce_sum_op by add self-kernel

* set all ReduceKernel MPType for accuracy

* add float16 test script which input is integer number

* solve reduce sum float16 check_grad problem

* solve conflict and change test script for CI

* change kernel register for CI

* remove all useless template
---
 paddle/fluid/operators/kron_op.h              |  14 +-
 paddle/fluid/operators/matmul_v2_op.h         |  12 +-
 paddle/fluid/operators/pool_op.h              |   6 +-
 .../fluid/operators/reduce_ops/cub_reduce.h   | 167 +++++++++++++-----
 .../operators/reduce_ops/reduce_sum_op.cc     |   3 +
 .../operators/reduce_ops/reduce_sum_op.cu     |  14 +-
 .../reduce_ops/reduce_sum_op.part.cu          |   1 +
 paddle/fluid/operators/trace_op.cu            |  10 +-
 python/paddle/fluid/layers/nn.py              |   3 +-
 .../fluid/tests/unittests/test_reduce_op.py   |  50 ++++++
 10 files changed, 208 insertions(+), 72 deletions(-)

diff --git a/paddle/fluid/operators/kron_op.h b/paddle/fluid/operators/kron_op.h
index 6c3bad4e1bd..ea2050fe8e6 100644
--- a/paddle/fluid/operators/kron_op.h
+++ b/paddle/fluid/operators/kron_op.h
@@ -237,11 +237,13 @@ struct KronGradElemFunctor<platform::complex<T>> {
   const int ndims_;
 };
 
-template <typename T>
 struct IdentityFunctor {
   HOSTDEVICE explicit inline IdentityFunctor() {}
 
-  HOSTDEVICE inline T operator()(const T& x) const { return x; }
+  template <typename U>
+  HOSTDEVICE inline U operator()(const U& x) const {
+    return x;
+  }
 };
 
 template <typename DeviceContext, typename T>
@@ -312,13 +314,13 @@ struct KronGradOpFunctor {
 #if defined(__NVCC__) || defined(__HIPCC__)
     auto stream = dev_ctx.stream();  // it is a cuda device_context
     if (dx) {
-      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
-          dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
+      TensorReduce<T, T, cub::Sum, IdentityFunctor>(
+          dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor(),
           stream);
     }
     if (dy) {
-      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
-          dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
+      TensorReduce<T, T, cub::Sum, IdentityFunctor>(
+          dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor(),
           stream);
     }
 #else
diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h
index 6061679b288..5b114f38199 100644
--- a/paddle/fluid/operators/matmul_v2_op.h
+++ b/paddle/fluid/operators/matmul_v2_op.h
@@ -34,11 +34,13 @@ namespace operators {
 
 using framework::Tensor;
 
-template <typename T>
 struct IdentityFunctor {
   HOSTDEVICE explicit inline IdentityFunctor() {}
 
-  HOSTDEVICE inline T operator()(const T& x) const { return x; }
+  template <typename U>
+  HOSTDEVICE inline U operator()(const U& x) const {
+    return x;
+  }
 };
 
 template <typename DeviceContext, typename T>
@@ -47,9 +49,9 @@ void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
                             const paddle::framework::ExecutionContext& ctx) {
 #if defined(__NVCC__) || defined(__HIPCC__)
   auto stream = ctx.cuda_device_context().stream();
-  TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
-      *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
-      IdentityFunctor<T>(), stream);
+  TensorReduce<T, T, cub::Sum, IdentityFunctor>(*input, output, reduce_dims,
+                                                static_cast<T>(0), cub::Sum(),
+                                                IdentityFunctor(), stream);
 #else
   ReduceKernelFunctor<DeviceContext, T, ops::SumFunctor>(
       input, output, reduce_dims, true, false, ctx)
diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h
index 9117b1b95ed..e84c92d9a16 100644
--- a/paddle/fluid/operators/pool_op.h
+++ b/paddle/fluid/operators/pool_op.h
@@ -31,7 +31,11 @@ namespace operators {
 template <class T>
 struct DivideFunctor {
   HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
-  HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
+
+  template <typename U>
+  HOSTDEVICE inline U operator()(const U& x) const {
+    return x * static_cast<U>(n_inv);
+  }
 
  private:
   T n_inv;
diff --git a/paddle/fluid/operators/reduce_ops/cub_reduce.h b/paddle/fluid/operators/reduce_ops/cub_reduce.h
index 9e1aed5dde4..0aab680e13d 100644
--- a/paddle/fluid/operators/reduce_ops/cub_reduce.h
+++ b/paddle/fluid/operators/reduce_ops/cub_reduce.h
@@ -31,6 +31,7 @@ namespace cub = hipcub;
 
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 
 namespace paddle {
 namespace operators {
@@ -66,39 +67,66 @@ struct Array {
   T data_[ElementCount];
 };
 
+// reduce the 1d array to one element
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim>
+__global__ void ReduceKernel1D(const Tx* x, Ty* y, ReduceOp reducer,
+                               TransformOp transformer, MPType init,
+                               int reduce_num) {
+  int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+
+  typedef cub::BlockReduce<MPType, BlockDim> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  MPType local_data = init;
+  for (int i = thread_id; i < reduce_num; i += gridDim.x * blockDim.x) {
+    local_data = static_cast<MPType>(
+        reducer(local_data, static_cast<MPType>(transformer(x[i]))));
+  }
+  __syncthreads();
+
+  local_data = BlockReduce(temp_storage).Reduce(local_data, reducer);
+
+  if (threadIdx.x == 0) {
+    y[blockIdx.x] = static_cast<Ty>(local_data);
+  }
+}
+
 // reduce the last axis of 2d array
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int BlockDim>
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim>
 __global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
-                               TransformOp transformer, Ty init,
+                               TransformOp transformer, MPType init,
                                int reduce_num) {
-  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
+  __shared__
+      typename cub::BlockReduce<MPType, BlockDim>::TempStorage temp_storage;
   int idx_x = blockIdx.x * reduce_num;
   int idx_y = threadIdx.x;
-  Ty reduce_var = init;
+  MPType reduce_var = init;
   for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim)
     reduce_var =
-        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
+        reducer(reduce_var, static_cast<MPType>(transformer(x[idx_x + idx_y])));
   __syncthreads();
 
-  reduce_var =
-      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
+  reduce_var = cub::BlockReduce<MPType, BlockDim>(temp_storage)
+                   .Reduce(reduce_var, reducer);
 
   if (threadIdx.x == 0) {
-    y[blockIdx.x] = reduce_var;
+    y[blockIdx.x] = static_cast<Ty>(reduce_var);
   }
 }
 
-template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
-          int BlockDim, int Rank, int ReduceRank>
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim, int Rank, int ReduceRank>
 __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
-                             TransformOp transformer, Ty init, int reduce_num,
-                             Array<int, Rank> x_strides,
+                             TransformOp transformer, MPType init,
+                             int reduce_num, Array<int, Rank> x_strides,
                              Array<int, ReduceRank> reduce_dim,
                              Array<int, ReduceRank> reduce_strides,
                              Array<int, Rank - ReduceRank> left_dim,
                              Array<int, Rank - ReduceRank> left_strides) {
-  __shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
+  __shared__
+      typename cub::BlockReduce<MPType, BlockDim>::TempStorage temp_storage;
   Array<int, Rank> sub_index;
   int left_idx = blockIdx.x;
   for (int i = 0; i < Rank - ReduceRank; ++i) {
@@ -114,7 +142,7 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
 
   int idx_x = 0;
   for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
-  Ty reduce_var = static_cast<Ty>(transformer(x[idx_x]));
+  MPType reduce_var = static_cast<MPType>(transformer(x[idx_x]));
 
   for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
     int reduce_idx = i;
@@ -125,16 +153,16 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
 
     int idx_x = 0;
     for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
-    reduce_var = static_cast<Ty>(
-        reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x]))));
+    reduce_var = static_cast<MPType>(
+        reducer(reduce_var, static_cast<MPType>(transformer(x[idx_x]))));
   }
   __syncthreads();
 
-  reduce_var =
-      cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
+  reduce_var = cub::BlockReduce<MPType, BlockDim>(temp_storage)
+                   .Reduce(reduce_var, reducer);
 
   if (threadIdx.x == 0) {
-    y[blockIdx.x] = reduce_var;
+    y[blockIdx.x] = static_cast<Ty>(reduce_var);
   }
 }
 
@@ -192,6 +220,53 @@ static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
   }
 }
 
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim>
+typename std::enable_if<!std::is_same<Tx, paddle::platform::float16>::value,
+                        void>::type
+LaunchCubReduceKernel(const Tx* x_data, Ty* y_data,
+                      const platform::Place& place, const ReduceOp& reducer,
+                      const TransformOp& transformer, const MPType& init,
+                      int reduce_num, gpuStream_t stream) {
+  cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
+                                                                  transformer);
+  size_t temp_storage_bytes = 0;
+  cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
+                            reduce_num, reducer, init, stream);
+  framework::Tensor tmp;
+  auto* temp_storage = tmp.mutable_data<uint8_t>(
+      framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}), place);
+  cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
+                            reduce_num, reducer, init, stream);
+}
+
+template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
+          typename TransformOp, int BlockDim>
+typename std::enable_if<std::is_same<Tx, paddle::platform::float16>::value,
+                        void>::type
+LaunchCubReduceKernel(const Tx* x_data, Ty* y_data,
+                      const platform::Place& place, const ReduceOp& reducer,
+                      const TransformOp& transformer, const MPType& init,
+                      int reduce_num, gpuStream_t stream) {
+  int element_per_block = BlockDim * 10;
+  int block_per_grid = (reduce_num + element_per_block - 1) / element_per_block;
+
+  framework::Tensor tmp;
+  auto* temp_storage = tmp.mutable_data<MPType>(
+      framework::make_ddim(
+          {static_cast<int64_t>(block_per_grid * sizeof(MPType))}),
+      place);
+
+  // each block reduce number to interim result
+  ReduceKernel1D<Tx, MPType, MPType, ReduceOp, TransformOp,
+                 BlockDim><<<block_per_grid, BlockDim, 0, stream>>>(
+      x_data, temp_storage, reducer, transformer, init, reduce_num);
+  // reduce all number to final result
+  ReduceKernel1D<MPType, MPType, Ty, ReduceOp, TransformOp,
+                 BlockDim><<<1, BlockDim, 0, stream>>>(
+      temp_storage, y_data, reducer, transformer, init, block_per_grid);
+}
+
 template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
           int BlockDim>
 static void TensorReduceImpl(
@@ -201,45 +276,40 @@ static void TensorReduceImpl(
     const Tx* x_data, Ty* y_data, const platform::Place& place,
     const ReduceOp& reducer, const TransformOp& transformer, const Ty& init,
     int left_num, int reduce_num, const std::vector<int>& x_strides,
     const std::vector<int>& reduce_dim, const std::vector<int>& reduce_strides,
     const std::vector<int>& left_dim, const std::vector<int>& left_strides,
     gpuStream_t stream) {
+  using MPType = typename details::MPTypeTrait<Ty>::Type;
+  MPType init_mp = static_cast<MPType>(init);
+
 #define CUB_RANK_CASE(i, ...)                                                \
   case i: {                                                                  \
     constexpr auto kRank = i;                                                \
     switch (reduce_rank) { __VA_ARGS__; }                                    \
   } break
 
-#define CUB_REDUCE_RANK_CASE(i, ...)                                         \
-  case i: {                                                                  \
-    constexpr auto kReduceRank = i;                                          \
-    ReduceKernel<Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank,             \
-                 kReduceRank><<<left_num, BlockDim, 0, stream>>>(            \
-        x_data, y_data, reducer, transformer, init, reduce_num,              \
-        Array<int, kRank>::From(x_strides),                                  \
-        Array<int, kReduceRank>::From(reduce_dim),                           \
-        Array<int, kReduceRank>::From(reduce_strides),                       \
-        Array<int, kRank - kReduceRank>::From(left_dim),                     \
-        Array<int, kRank - kReduceRank>::From(left_strides));                \
+#define CUB_REDUCE_RANK_CASE(i, ...)                                         \
+  case i: {                                                                  \
+    constexpr auto kReduceRank = i;                                          \
+    ReduceKernel<Tx, MPType, Ty, ReduceOp, TransformOp, BlockDim, kRank,     \
+                 kReduceRank><<<left_num, BlockDim, 0, stream>>>(            \
+        x_data, y_data, reducer, transformer, init_mp, reduce_num,           \
+        Array<int, kRank>::From(x_strides),                                  \
+        Array<int, kReduceRank>::From(reduce_dim),                           \
+        Array<int, kReduceRank>::From(reduce_strides),                       \
+        Array<int, kRank - kReduceRank>::From(left_dim),                     \
+        Array<int, kRank - kReduceRank>::From(left_strides));                \
   } break
 
   int rank = x_strides.size();
   int reduce_rank = reduce_strides.size();
   if (rank == reduce_rank) {
-    cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
-        x_data, transformer);
-    size_t temp_storage_bytes = 0;
-    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
-                              reduce_num, reducer, init, stream);
-    framework::Tensor tmp;
-    auto* temp_storage = tmp.mutable_data<uint8_t>(
-        framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
-        place);
-    cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
-                              reduce_num, reducer, init, stream);
+    LaunchCubReduceKernel<Tx, MPType, Ty, ReduceOp, TransformOp, BlockDim>(
+        x_data, y_data, place, reducer, transformer, init_mp, reduce_num,
+        stream);
     return;
   }
   if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
-    ReduceKernel2D<Tx, Ty, ReduceOp, TransformOp,
-                   BlockDim><<<left_num, BlockDim, 0, stream>>>(
-        x_data, y_data, reducer, transformer, init, reduce_num);
+    ReduceKernel2D<Tx, MPType, Ty, ReduceOp, TransformOp,
+                   BlockDim><<<left_num, BlockDim, 0, stream>>>(
+        x_data, y_data, reducer, transformer, init_mp, reduce_num);
     return;
   }
   /*
@@ -366,8 +436,7 @@ void TensorReduce(const framework::Tensor& x, framework::Tensor* y,
 #undef CUB_BLOCK_DIM_CASE
 }
 
-template <typename Tx, typename ReduceOp,
-          template <typename, typename> class TransformOp>
+template <typename Tx, typename ReduceOp, template <typename> class TransformOp>
 struct TensorReduceFunctor {
   const framework::Tensor& x;
   framework::Tensor* y;
@@ -389,9 +458,9 @@ struct TensorReduceFunctor {
 
   void apply() const {
     const Ty& init_cast = static_cast<Ty>(init);
-    TensorReduce<Tx, Ty, ReduceOp, TransformOp<Tx, Ty>>(
-        x, y, origin_reduce_dims, init_cast, reducer, TransformOp<Tx, Ty>(),
-        stream);
+    TensorReduce<Tx, Ty, ReduceOp, TransformOp<Ty>>(x, y, origin_reduce_dims,
+                                                    init_cast, reducer,
+                                                    TransformOp<Ty>(), stream);
   }
 };
 
diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
index 74e7db649d5..9e4cc8e213c 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
+++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc
@@ -115,6 +115,8 @@ REGISTER_OP_CPU_KERNEL(
                       ops::SumFunctor>,
     ops::ReduceKernel<paddle::platform::CPUDeviceContext, float,
                       ops::SumFunctor>,
+    ops::ReduceKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::float16, ops::SumFunctor>,
     ops::ReduceKernel<paddle::platform::CPUDeviceContext, double,
                       ops::SumFunctor>,
     ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::SumFunctor>,
@@ -133,6 +135,7 @@ using CPUReduceSumGradKernel =
 REGISTER_OP_CPU_KERNEL(
     reduce_sum_grad, CPUReduceSumGradKernel<bool>,
     CPUReduceSumGradKernel<float>, CPUReduceSumGradKernel<double>,
+    CPUReduceSumGradKernel<paddle::platform::float16>,
     CPUReduceSumGradKernel<int>, CPUReduceSumGradKernel<int64_t>,
     CPUReduceSumGradKernel<paddle::platform::complex<float>>,
     CPUReduceSumGradKernel<paddle::platform::complex<double>>);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
index dd16ca4e393..efbafe4aa8c 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cu
@@ -18,12 +18,13 @@
 namespace paddle {
 namespace operators {
 
-template <typename Tx, typename Ty = Tx>
+template <typename Tout>
 struct IdentityFunctor {
   HOSTDEVICE explicit inline IdentityFunctor() {}
 
-  HOSTDEVICE inline Ty operator()(const Tx& x) const {
-    return static_cast<Ty>(x);
+  template <typename U>
+  HOSTDEVICE inline Tout operator()(const U& x) const {
+    return static_cast<Tout>(x);
   }
 };
 
@@ -62,9 +63,9 @@ class ReduceSumKernel : public framework::OpKernel<T> {
           *input, output, reduce_dims, static_cast<double>(0.0), cub::Sum(),
           stream));
     } else {
-      TensorReduce<T, T, cub::Sum, IdentityFunctor<T, T>>(
+      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
           *input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
-          IdentityFunctor<T, T>(), stream);
+          IdentityFunctor<T>(), stream);
     }
   }
 };
@@ -74,7 +75,8 @@ class ReduceSumKernel : public framework::OpKernel<T> {
 
 REGISTER_OP_CUDA_KERNEL(
     reduce_sum, ops::ReduceSumKernel<bool>, ops::ReduceSumKernel<float>,
-    ops::ReduceSumKernel<double>, ops::ReduceSumKernel<int>,
+    ops::ReduceSumKernel<double>,
+    ops::ReduceSumKernel<paddle::platform::float16>, ops::ReduceSumKernel<int>,
     ops::ReduceSumKernel<int64_t>,
     ops::ReduceSumKernel<paddle::platform::complex<float>>,
     ops::ReduceSumKernel<paddle::platform::complex<double>>);
diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu
index 230bae0cdd4..419b8ce2765 100644
--- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu
+++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.part.cu
@@ -23,6 +23,7 @@ using CUDAReduceSumGradKernel =
 REGISTER_OP_CUDA_KERNEL(
     reduce_sum_grad, CUDAReduceSumGradKernel<bool>,
     CUDAReduceSumGradKernel<float>, CUDAReduceSumGradKernel<double>,
+    CUDAReduceSumGradKernel<paddle::platform::float16>,
    CUDAReduceSumGradKernel<int>, CUDAReduceSumGradKernel<int64_t>,
     CUDAReduceSumGradKernel<paddle::platform::complex<float>>,
     CUDAReduceSumGradKernel<paddle::platform::complex<double>>);
diff --git a/paddle/fluid/operators/trace_op.cu b/paddle/fluid/operators/trace_op.cu
index 6798521c8f7..336c1c40832 100644
--- a/paddle/fluid/operators/trace_op.cu
+++ b/paddle/fluid/operators/trace_op.cu
@@ -20,11 +20,13 @@
 namespace paddle {
 namespace operators {
 
-template <typename T>
 struct IdentityFunctor {
   HOSTDEVICE explicit inline IdentityFunctor() {}
 
-  HOSTDEVICE inline T operator()(const T& x) const { return x; }
+  template <typename U>
+  HOSTDEVICE inline U operator()(const U& x) const {
+    return x;
+  }
 };
 
 template <typename DeviceContext, typename T>
@@ -45,9 +47,9 @@ class TraceCUDAKernel : public framework::OpKernel<T> {
       auto stream = context.cuda_device_context().stream();
       std::vector<int> reduce_dims;
       reduce_dims.push_back(out->dims().size());
-      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
+      TensorReduce<T, T, cub::Sum, IdentityFunctor>(
           diag, out, reduce_dims, static_cast<T>(0), cub::Sum(),
-          IdentityFunctor<T>(), stream);
+          IdentityFunctor(), stream);
     }
   }
 };
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index e02edb72ce1..7e50646c0c4 100755
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -4424,7 +4424,8 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
         if dim == None or dim == [] or len(dim) == len(input.shape) else False
     }
     check_variable_and_dtype(
-        input, 'input', ['float32', 'float64', 'int32', 'int64'], 'reduce_sum')
+        input, 'input', ['float16', 'float32', 'float64', 'int32', 'int64'],
+        'reduce_sum')
     helper = LayerHelper('reduce_sum', **locals())
     out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
     helper.append_op(
diff --git a/python/paddle/fluid/tests/unittests/test_reduce_op.py b/python/paddle/fluid/tests/unittests/test_reduce_op.py
index 912df563fcd..2dd5bcb8113 100644
--- a/python/paddle/fluid/tests/unittests/test_reduce_op.py
+++ b/python/paddle/fluid/tests/unittests/test_reduce_op.py
@@ -37,6 +37,56 @@ class TestSumOp(OpTest):
         self.check_grad(['X'], 'Out')
 
 
+class TestSumOp_fp16(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.inputs = {
+            'X': np.random.uniform(0, 0.1, (5, 6, 10)).astype("float16")
+        }
+        self.attrs = {'dim': [0, 1, 2]}
+        self.outputs = {
+            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
+        }
+        self.gradient = self.calc_gradient()
+
+    def test_check_output(self):
+        self.check_output()
+
+    def calc_gradient(self):
+        x = self.inputs["X"]
+        grad = np.ones(x.shape, dtype=x.dtype)
+        return grad,
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', user_defined_grads=self.gradient)
+
+
+class TestSumOp_fp16_withInt(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.inputs = {
+            # ref to https://en.wikipedia.org/wiki/Half-precision_floating-point_format
+            # Precision limitations on integer values between 0 and 2048 can be exactly represented
+            'X': np.random.randint(0, 30, (10, 10)).astype("float16")
+        }
+        self.attrs = {'dim': [0, 1]}
+        self.outputs = {
+            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
+        }
+        self.gradient = self.calc_gradient()
+
+    def test_check_output(self):
+        self.check_output()
+
+    def calc_gradient(self):
+        x = self.inputs["X"]
+        grad = np.ones(x.shape, dtype=x.dtype)
+        return grad,
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', user_defined_grads=self.gradient)
+
+
 class TestSumOp5D(OpTest):
     def setUp(self):
         self.op_type = "reduce_sum"
-- 
GitLab
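
The core of the change above is that every reduction kernel now accumulates partial sums in `MPType` (`details::MPTypeTrait<Ty>::Type`, i.e. float32 when the tensor type is float16) and only casts back to the output type when the result is written, while the new unit tests feed either small uniform values or small integer-valued float16 data so the NumPy reference stays representable in half precision. The snippet below is a minimal NumPy sketch, not part of the patch, illustrating why accumulating float16 inputs in a wider type improves the accuracy of reduce_sum; the array shape and value range mirror `TestSumOp_fp16`.

```python
import numpy as np

x = np.random.uniform(0, 0.1, (5, 6, 10)).astype("float16")

# Running sum kept entirely in float16: rounding error grows with the partial sum.
acc_fp16 = np.float16(0)
for v in x.flat:
    acc_fp16 = np.float16(acc_fp16 + v)

# Running sum kept in float32 (the role MPType plays in the CUDA kernels),
# cast back to float16 only once at the end.
acc_fp32 = np.float32(0)
for v in x.flat:
    acc_fp32 += np.float32(v)
result_fp16 = np.float16(acc_fp32)

reference = x.astype(np.float64).sum()
print(abs(float(acc_fp16) - reference), abs(float(result_fp16) - reference))
# The float32-accumulated result is typically much closer to the float64 reference.
```

The integer-valued test case (`TestSumOp_fp16_withInt`) relies on the fact noted in its comment: small integers are exactly representable in float16, so the op output and the NumPy reference can be compared without additional tolerance.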