Unverified commit 606939de, authored by J jiangcheng, committed by GitHub

Support reduce_sum_op float16 (#32966)

* add reduce_sum_op float16 support by adding a self-contained kernel

* use MPType in all ReduceKernel paths for accuracy

* add a float16 test case whose inputs are integer-valued

* fix the reduce_sum float16 check_grad problem

* resolve conflicts and adjust the test script for CI

* adjust kernel registration for CI

* remove all unneeded template parameters
Parent 02a6d49a
......@@ -237,11 +237,13 @@ struct KronGradElemFunctor<platform::complex<T>> {
const int ndims_;
};
template <typename T>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE inline T operator()(const T& x) const { return x; }
template <typename U>
HOSTDEVICE inline U operator()(const U& x) const {
return x;
}
};
template <typename DeviceContext, typename T>
......@@ -312,13 +314,13 @@ struct KronGradOpFunctor {
#if defined(__NVCC__) || defined(__HIPCC__)
auto stream = dev_ctx.stream(); // it is a cuda device_context
if (dx) {
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
TensorReduce<T, T, cub::Sum, IdentityFunctor>(
dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor(),
stream);
}
if (dy) {
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
TensorReduce<T, T, cub::Sum, IdentityFunctor>(
dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor(),
stream);
}
#else
......
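The kron gradient code above switches its call sites from the class-templated `IdentityFunctor<T>` to a plain `IdentityFunctor` whose call operator is templated, so a single functor type can be applied both to the raw input type and to the mixed-precision accumulation type introduced later in this commit. Below is a minimal sketch of that pattern, not the Paddle source; it is ordinary C++ with the `HOSTDEVICE` qualifier omitted.

```cpp
// Sketch of the refactor above: the element type moves from the struct to
// operator(), so one functor object can transform values of any type.
#include <iostream>

struct IdentityFunctor {
  explicit IdentityFunctor() {}

  // The argument type U is deduced at each call site, which is what lets the
  // reduction apply the same functor to Tx values and to MPType values.
  template <typename U>
  inline U operator()(const U& x) const {
    return x;
  }
};

int main() {
  IdentityFunctor identity;
  std::cout << identity(3.5f) << " "   // float in, float out
            << identity(7.0) << "\n";  // double in, double out
  return 0;
}
```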
......@@ -34,11 +34,13 @@ namespace operators {
using framework::Tensor;
template <typename T>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE inline T operator()(const T& x) const { return x; }
template <typename U>
HOSTDEVICE inline U operator()(const U& x) const {
return x;
}
};
template <typename DeviceContext, typename T>
......@@ -47,9 +49,9 @@ void ReduceSumForMatmulGrad(const Tensor* input, Tensor* output,
const paddle::framework::ExecutionContext& ctx) {
#if defined(__NVCC__) || defined(__HIPCC__)
auto stream = ctx.cuda_device_context().stream();
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
*input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
IdentityFunctor<T>(), stream);
TensorReduce<T, T, cub::Sum, IdentityFunctor>(*input, output, reduce_dims,
static_cast<T>(0), cub::Sum(),
IdentityFunctor(), stream);
#else
ReduceKernelFunctor<DeviceContext, T, ops::SumFunctor>(
input, output, reduce_dims, true, false, ctx)
......
......@@ -31,7 +31,11 @@ namespace operators {
template <typename T>
struct DivideFunctor {
HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
HOSTDEVICE inline T operator()(const T& x) const { return x * n_inv; }
template <typename U>
HOSTDEVICE inline U operator()(const U& x) const {
return x * static_cast<U>(n_inv);
}
private:
T n_inv;
......
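The `DivideFunctor` above gets the same treatment: the reciprocal stays stored as `T`, but the call operator now deduces its argument type and casts `n_inv` into it, so the functor keeps working when the reduction accumulates in a wider type than the input. A hedged stand-alone sketch (plain C++, no `HOSTDEVICE`, names follow the hunk above):

```cpp
// n_inv is stored once in T, but operator() deduces U per call and casts,
// so the same functor works when the accumulator is wider than the input.
#include <iostream>

template <typename T>
struct DivideFunctor {
  explicit DivideFunctor(int n) : n_inv(static_cast<T>(1.0 / n)) {}

  template <typename U>
  inline U operator()(const U& x) const {
    return x * static_cast<U>(n_inv);  // divide-by-n expressed as multiply-by-1/n
  }

 private:
  T n_inv;
};

int main() {
  DivideFunctor<float> divide_by_4(4);
  std::cout << divide_by_4(10.0f) << " "   // float path  -> 2.5
            << divide_by_4(10.0) << "\n";  // double path -> 2.5
  return 0;
}
```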
......@@ -31,6 +31,7 @@ namespace cub = hipcub;
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
namespace paddle {
namespace operators {
......@@ -66,39 +67,66 @@ struct Array {
T data_[ElementCount];
};
// reduce the 1d array to one element
template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
typename TransformOp, int BlockDim>
__global__ void ReduceKernel1D(const Tx* x, Ty* y, ReduceOp reducer,
TransformOp transformer, MPType init,
int reduce_num) {
int thread_id = blockIdx.x * blockDim.x + threadIdx.x;
typedef cub::BlockReduce<MPType, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
MPType local_data = init;
for (int i = thread_id; i < reduce_num; i += gridDim.x * blockDim.x) {
local_data = static_cast<MPType>(
reducer(local_data, static_cast<MPType>(transformer(x[i]))));
}
__syncthreads();
local_data = BlockReduce(temp_storage).Reduce(local_data, reducer);
if (threadIdx.x == 0) {
y[blockIdx.x] = static_cast<Ty>(local_data);
}
}
// reduce the last axis of 2d array
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
int BlockDim>
template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
typename TransformOp, int BlockDim>
__global__ void ReduceKernel2D(const Tx* x, Ty* y, ReduceOp reducer,
TransformOp transformer, Ty init,
TransformOp transformer, MPType init,
int reduce_num) {
__shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
__shared__
typename cub::BlockReduce<MPType, BlockDim>::TempStorage temp_storage;
int idx_x = blockIdx.x * reduce_num;
int idx_y = threadIdx.x;
Ty reduce_var = init;
MPType reduce_var = init;
for (int idx_y = threadIdx.x; idx_y < reduce_num; idx_y += BlockDim)
reduce_var =
reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x + idx_y])));
reducer(reduce_var, static_cast<MPType>(transformer(x[idx_x + idx_y])));
__syncthreads();
reduce_var =
cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
reduce_var = cub::BlockReduce<MPType, BlockDim>(temp_storage)
.Reduce(reduce_var, reducer);
if (threadIdx.x == 0) {
y[blockIdx.x] = reduce_var;
y[blockIdx.x] = static_cast<Ty>(reduce_var);
}
}
template <typename Tx, typename Ty, typename ReduceOp, typename TransformOp,
int BlockDim, int Rank, int ReduceRank>
template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
typename TransformOp, int BlockDim, int Rank, int ReduceRank>
__global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
TransformOp transformer, Ty init, int reduce_num,
Array<int, Rank> x_strides,
TransformOp transformer, MPType init,
int reduce_num, Array<int, Rank> x_strides,
Array<int, ReduceRank> reduce_dim,
Array<int, ReduceRank> reduce_strides,
Array<int, Rank - ReduceRank> left_dim,
Array<int, Rank - ReduceRank> left_strides) {
__shared__ typename cub::BlockReduce<Ty, BlockDim>::TempStorage temp_storage;
__shared__
typename cub::BlockReduce<MPType, BlockDim>::TempStorage temp_storage;
Array<int, Rank> sub_index;
int left_idx = blockIdx.x;
for (int i = 0; i < Rank - ReduceRank; ++i) {
......@@ -114,7 +142,7 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
int idx_x = 0;
for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
Ty reduce_var = static_cast<Ty>(transformer(x[idx_x]));
MPType reduce_var = static_cast<MPType>(transformer(x[idx_x]));
for (int i = threadIdx.x + BlockDim; i < reduce_num; i += BlockDim) {
int reduce_idx = i;
......@@ -125,16 +153,16 @@ __global__ void ReduceKernel(const Tx* x, Ty* y, ReduceOp reducer,
int idx_x = 0;
for (int k = 0; k < Rank; ++k) idx_x += (sub_index[k] * x_strides[k]);
reduce_var = static_cast<Ty>(
reducer(reduce_var, static_cast<Ty>(transformer(x[idx_x]))));
reduce_var = static_cast<MPType>(
reducer(reduce_var, static_cast<MPType>(transformer(x[idx_x]))));
}
__syncthreads();
reduce_var =
cub::BlockReduce<Ty, BlockDim>(temp_storage).Reduce(reduce_var, reducer);
reduce_var = cub::BlockReduce<MPType, BlockDim>(temp_storage)
.Reduce(reduce_var, reducer);
if (threadIdx.x == 0) {
y[blockIdx.x] = reduce_var;
y[blockIdx.x] = static_cast<Ty>(reduce_var);
}
}
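In all three kernels above, `reduce_var` is now declared as `MPType` rather than `Ty`, the `cub::BlockReduce` temporary storage is typed on `MPType`, and the value is cast back to `Ty` only at the final store. The motivation is accumulation accuracy: summing in a narrow type loses low-order bits once the running total grows. The toy program below illustrates the effect one precision level up (float data, float vs. double accumulator); with float16 data and a float accumulator, the gap this commit closes is far larger. It is an illustration of the precision effect only, not of the CUDA kernels themselves.

```cpp
// Why the accumulator type (MPType) matters for long reductions.
#include <cstdio>

int main() {
  const int n = 10 * 1000 * 1000;  // ten million terms
  float narrow_acc = 0.0f;         // accumulate in the data's own type
  double wide_acc = 0.0;           // accumulate in a wider "MPType"
  for (int i = 0; i < n; ++i) {
    const float x = 0.1f;
    narrow_acc += x;
    wide_acc += static_cast<double>(x);
  }
  // The wide accumulator lands essentially on n * 0.1f; the narrow one is
  // visibly off because each addition rounds to the coarse float grid near 1e6.
  std::printf("float accumulator:  %.3f\n", narrow_acc);
  std::printf("double accumulator: %.3f\n", wide_acc);
  return 0;
}
```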
......@@ -192,6 +220,53 @@ static inline void CheckReduceRankIsValid(int reduce_rank, int rank) {
}
}
template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
typename TransformOp, int BlockDim>
typename std::enable_if<!std::is_same<Tx, paddle::platform::float16>::value,
void>::type
LaunchCubReduceKernel(const Tx* x_data, Ty* y_data,
const platform::Place& place, const ReduceOp& reducer,
const TransformOp& transformer, const MPType& init,
int reduce_num, gpuStream_t stream) {
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
transformer);
size_t temp_storage_bytes = 0;
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
reduce_num, reducer, init, stream);
framework::Tensor tmp;
auto* temp_storage = tmp.mutable_data<uint8_t>(
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}), place);
cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
reduce_num, reducer, init, stream);
}
template <typename Tx, typename MPType, typename Ty, typename ReduceOp,
typename TransformOp, int BlockDim>
typename std::enable_if<std::is_same<Tx, paddle::platform::float16>::value,
void>::type
LaunchCubReduceKernel(const Tx* x_data, Ty* y_data,
const platform::Place& place, const ReduceOp& reducer,
const TransformOp& transformer, const MPType& init,
int reduce_num, gpuStream_t stream) {
int element_per_block = BlockDim * 10;
int block_per_grid = (reduce_num + element_per_block - 1) / element_per_block;
framework::Tensor tmp;
auto* temp_storage = tmp.mutable_data<MPType>(
framework::make_ddim(
{static_cast<int64_t>(block_per_grid * sizeof(MPType))}),
place);
// each block reduce number to interim result
ReduceKernel1D<Tx, MPType, MPType, ReduceOp, TransformOp,
BlockDim><<<block_per_grid, BlockDim, 0, stream>>>(
x_data, temp_storage, reducer, transformer, init, reduce_num);
// reduce all number to final result
ReduceKernel1D<MPType, MPType, Ty, ReduceOp, TransformOp,
BlockDim><<<1, BlockDim, 0, stream>>>(
temp_storage, y_data, reducer, transformer, init, block_per_grid);
}
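`LaunchCubReduceKernel` is split into two overloads selected at compile time with `std::enable_if`: every type except `paddle::platform::float16` keeps the original `cub::DeviceReduce::Reduce` path, while the float16 overload launches the new `ReduceKernel1D` twice, first so that each block writes one MPType partial sum into a temporary buffer, then with a single block to fold those partials into the final `Ty` result, so the accumulation never happens in half precision. Below is a minimal sketch of the dispatch mechanism only; the names are illustrative stand-ins, not Paddle's.

```cpp
// std::enable_if picks one overload for float16 inputs and another for the rest.
#include <iostream>
#include <type_traits>

struct float16 {};  // stand-in for paddle::platform::float16

template <typename Tx>
typename std::enable_if<!std::is_same<Tx, float16>::value, void>::type
LaunchReduce(const char* what) {
  // Generic path: in the Paddle overload above this wraps cub::DeviceReduce::Reduce.
  std::cout << "cub::DeviceReduce path for " << what << "\n";
}

template <typename Tx>
typename std::enable_if<std::is_same<Tx, float16>::value, void>::type
LaunchReduce(const char* what) {
  // float16 path: in the Paddle overload above this runs ReduceKernel1D twice,
  // accumulating partial sums in MPType to preserve accuracy.
  std::cout << "two-pass ReduceKernel1D path for " << what << "\n";
}

int main() {
  LaunchReduce<float>("float");      // resolves to the generic overload
  LaunchReduce<float16>("float16");  // resolves to the float16 overload
  return 0;
}
```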
template <typename Tx, typename Ty, int BlockDim, typename ReduceOp,
typename TransformOp>
static void TensorReduceImpl(
......@@ -201,45 +276,40 @@ static void TensorReduceImpl(
const std::vector<int>& reduce_dim, const std::vector<int>& reduce_strides,
const std::vector<int>& left_dim, const std::vector<int>& left_strides,
gpuStream_t stream) {
using MPType = typename details::MPTypeTrait<Ty>::Type;
MPType init_mp = static_cast<MPType>(init);
#define CUB_RANK_CASE(i, ...) \
case i: { \
constexpr auto kRank = i; \
switch (reduce_rank) { __VA_ARGS__; } \
} break
#define CUB_REDUCE_RANK_CASE(i, ...) \
case i: { \
constexpr auto kReduceRank = i; \
ReduceKernel<Tx, Ty, ReduceOp, TransformOp, BlockDim, kRank, \
kReduceRank><<<left_num, BlockDim, 0, stream>>>( \
x_data, y_data, reducer, transformer, init, reduce_num, \
Array<int, kRank>::From(x_strides), \
Array<int, kReduceRank>::From(reduce_dim), \
Array<int, kReduceRank>::From(reduce_strides), \
Array<int, kRank - kReduceRank>::From(left_dim), \
Array<int, kRank - kReduceRank>::From(left_strides)); \
#define CUB_REDUCE_RANK_CASE(i, ...) \
case i: { \
constexpr auto kReduceRank = i; \
ReduceKernel<Tx, MPType, Ty, ReduceOp, TransformOp, BlockDim, kRank, \
kReduceRank><<<left_num, BlockDim, 0, stream>>>( \
x_data, y_data, reducer, transformer, init_mp, reduce_num, \
Array<int, kRank>::From(x_strides), \
Array<int, kReduceRank>::From(reduce_dim), \
Array<int, kReduceRank>::From(reduce_strides), \
Array<int, kRank - kReduceRank>::From(left_dim), \
Array<int, kRank - kReduceRank>::From(left_strides)); \
} break
int rank = x_strides.size();
int reduce_rank = reduce_strides.size();
if (rank == reduce_rank) {
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(
x_data, transformer);
size_t temp_storage_bytes = 0;
cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, trans_x, y_data,
reduce_num, reducer, init, stream);
framework::Tensor tmp;
auto* temp_storage = tmp.mutable_data<uint8_t>(
framework::make_ddim({static_cast<int64_t>(temp_storage_bytes)}),
place);
cub::DeviceReduce::Reduce(temp_storage, temp_storage_bytes, trans_x, y_data,
reduce_num, reducer, init, stream);
LaunchCubReduceKernel<Tx, MPType, Ty, ReduceOp, TransformOp, BlockDim>(
x_data, y_data, place, reducer, transformer, init_mp, reduce_num,
stream);
return;
}
if (rank == 2 && reduce_rank == 1 && reduce_dim[0] == 1) {
ReduceKernel2D<Tx, Ty, ReduceOp, TransformOp,
ReduceKernel2D<Tx, MPType, Ty, ReduceOp, TransformOp,
BlockDim><<<left_num, BlockDim, 0, stream>>>(
x_data, y_data, reducer, transformer, init, reduce_num);
x_data, y_data, reducer, transformer, init_mp, reduce_num);
return;
}
/*
......@@ -366,8 +436,7 @@ void TensorReduce(const framework::Tensor& x, framework::Tensor* y,
#undef CUB_BLOCK_DIM_CASE
}
template <typename Tx, typename ReduceOp,
template <typename, typename> class TransformOp>
template <typename Tx, typename ReduceOp, template <typename> class TransformOp>
struct TensorReduceFunctor {
const framework::Tensor& x;
framework::Tensor* y;
......@@ -389,9 +458,9 @@ struct TensorReduceFunctor {
void apply() const {
const Ty& init_cast = static_cast<Ty>(init);
TensorReduce<Tx, Ty, ReduceOp, TransformOp<Tx, Ty>>(
x, y, origin_reduce_dims, init_cast, reducer, TransformOp<Tx, Ty>(),
stream);
TensorReduce<Tx, Ty, ReduceOp, TransformOp<Ty>>(x, y, origin_reduce_dims,
init_cast, reducer,
TransformOp<Ty>(), stream);
}
};
......
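Tying the pieces together, `TensorReduceImpl` now derives the accumulation type with `using MPType = typename details::MPTypeTrait<Ty>::Type` (from the newly included `amp/fp16_type_traits.h`), casts `init` into it, and hands it to the kernels, and `TensorReduceFunctor` drops one template parameter from `TransformOp` so it instantiates `TransformOp<Ty>`. The trait behaves as a simple type map: float16 is promoted to float for accumulation, every other type stays as itself. A hedged re-implementation sketch, with illustrative names rather than Paddle's:

```cpp
// Minimal re-implementation of the MPTypeTrait idea pulled in above.
#include <cstdint>
#include <iostream>
#include <type_traits>

struct float16 {  // stand-in for paddle::platform::float16 (storage only)
  std::uint16_t bits;
};

template <typename T>
struct MPTypeTrait {
  using Type = T;  // default: accumulate in the input type itself
};

template <>
struct MPTypeTrait<float16> {
  using Type = float;  // float16 is promoted to float for accumulation
};

int main() {
  static_assert(std::is_same<MPTypeTrait<double>::Type, double>::value, "");
  static_assert(std::is_same<MPTypeTrait<float16>::Type, float>::value, "");
  std::cout << "float16 accumulates as float\n";
  return 0;
}
```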
......@@ -115,6 +115,8 @@ REGISTER_OP_CPU_KERNEL(
ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, double,
ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::SumFunctor>,
......@@ -133,6 +135,7 @@ using CPUReduceSumGradKernel =
REGISTER_OP_CPU_KERNEL(
reduce_sum_grad, CPUReduceSumGradKernel<bool>,
CPUReduceSumGradKernel<float>, CPUReduceSumGradKernel<double>,
CPUReduceSumGradKernel<paddle::platform::float16>,
CPUReduceSumGradKernel<int>, CPUReduceSumGradKernel<int64_t>,
CPUReduceSumGradKernel<paddle::platform::complex<float>>,
CPUReduceSumGradKernel<paddle::platform::complex<double>>);
......@@ -18,12 +18,13 @@
namespace paddle {
namespace operators {
template <typename Tx, typename Ty = Tx>
template <typename Tout>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x);
template <typename U>
HOSTDEVICE inline Tout operator()(const U& x) const {
return static_cast<Tout>(x);
}
};
......@@ -62,9 +63,9 @@ class ReduceSumKernel : public framework::OpKernel<T> {
*input, output, reduce_dims, static_cast<double>(0.0), cub::Sum(),
stream));
} else {
TensorReduce<T, T, cub::Sum, IdentityFunctor<T, T>>(
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
*input, output, reduce_dims, static_cast<T>(0), cub::Sum(),
IdentityFunctor<T, T>(), stream);
IdentityFunctor<T>(), stream);
}
}
};
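Note that reduce_sum_op.cu keeps a class-level template parameter on its `IdentityFunctor`, but for the output type: `IdentityFunctor<Tout>` fixes what the transform produces (which lets `TensorReduceFunctor` instantiate `TransformOp<Ty>` per visited output dtype), while the input type is still deduced per call. A side-by-side sketch of the two functor shapes used in this commit, again plain C++ with illustrative names:

```cpp
// - DeducedIdentity mirrors the kron/matmul/trace variant (type in == type out).
// - OutputTypedIdentity mirrors the reduce_sum variant (input deduced, output fixed).
#include <iostream>

struct DeducedIdentity {
  template <typename U>
  U operator()(const U& x) const {
    return x;
  }
};

template <typename Tout>
struct OutputTypedIdentity {
  template <typename U>
  Tout operator()(const U& x) const {
    return static_cast<Tout>(x);
  }
};

int main() {
  DeducedIdentity same;
  OutputTypedIdentity<double> to_double;
  std::cout << same(2.5f) << " "         // stays float
            << to_double(2.5f) << " "    // float in, double out
            << to_double(3) << "\n";     // int in, double out
  return 0;
}
```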
......@@ -74,7 +75,8 @@ class ReduceSumKernel : public framework::OpKernel<T> {
REGISTER_OP_CUDA_KERNEL(
reduce_sum, ops::ReduceSumKernel<bool>, ops::ReduceSumKernel<float>,
ops::ReduceSumKernel<double>, ops::ReduceSumKernel<int>,
ops::ReduceSumKernel<double>,
ops::ReduceSumKernel<paddle::platform::float16>, ops::ReduceSumKernel<int>,
ops::ReduceSumKernel<int64_t>,
ops::ReduceSumKernel<paddle::platform::complex<float>>,
ops::ReduceSumKernel<paddle::platform::complex<double>>);
......@@ -23,6 +23,7 @@ using CUDAReduceSumGradKernel =
REGISTER_OP_CUDA_KERNEL(
reduce_sum_grad, CUDAReduceSumGradKernel<bool>,
CUDAReduceSumGradKernel<float>, CUDAReduceSumGradKernel<double>,
CUDAReduceSumGradKernel<paddle::platform::float16>,
CUDAReduceSumGradKernel<int>, CUDAReduceSumGradKernel<int64_t>,
CUDAReduceSumGradKernel<paddle::platform::complex<float>>,
CUDAReduceSumGradKernel<paddle::platform::complex<double>>);
......@@ -20,11 +20,13 @@
namespace paddle {
namespace operators {
template <typename T>
struct IdentityFunctor {
HOSTDEVICE explicit inline IdentityFunctor() {}
HOSTDEVICE inline T operator()(const T& x) const { return x; }
template <typename U>
HOSTDEVICE inline U operator()(const U& x) const {
return x;
}
};
template <typename DeviceContext, typename T>
......@@ -45,9 +47,9 @@ class TraceCUDAKernel : public framework::OpKernel<T> {
auto stream = context.cuda_device_context().stream();
std::vector<int> reduce_dims;
reduce_dims.push_back(out->dims().size());
TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
TensorReduce<T, T, cub::Sum, IdentityFunctor>(
diag, out, reduce_dims, static_cast<T>(0), cub::Sum(),
IdentityFunctor<T>(), stream);
IdentityFunctor(), stream);
}
}
};
......
......@@ -4424,7 +4424,8 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
if dim == None or dim == [] or len(dim) == len(input.shape) else False
}
check_variable_and_dtype(
input, 'input', ['float32', 'float64', 'int32', 'int64'], 'reduce_sum')
input, 'input', ['float16', 'float32', 'float64', 'int32', 'int64'],
'reduce_sum')
helper = LayerHelper('reduce_sum', **locals())
out = helper.create_variable_for_type_inference(dtype=helper.input_dtype())
helper.append_op(
......
......@@ -37,6 +37,56 @@ class TestSumOp(OpTest):
self.check_grad(['X'], 'Out')
class TestSumOp_fp16(OpTest):
    def setUp(self):
        self.op_type = "reduce_sum"
        self.inputs = {
            'X': np.random.uniform(0, 0.1, (5, 6, 10)).astype("float16")
        }
        self.attrs = {'dim': [0, 1, 2]}
        self.outputs = {
            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
        }
        self.gradient = self.calc_gradient()

    def test_check_output(self):
        self.check_output()

    def calc_gradient(self):
        x = self.inputs["X"]
        grad = np.ones(x.shape, dtype=x.dtype)
        return grad,

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', user_defined_grads=self.gradient)


class TestSumOp_fp16_withInt(OpTest):
    def setUp(self):
        self.op_type = "reduce_sum"
        self.inputs = {
            # ref to https://en.wikipedia.org/wiki/Half-precision_floating-point_format
            # integers between 0 and 2048 can be exactly represented in float16
            'X': np.random.randint(0, 30, (10, 10)).astype("float16")
        }
        self.attrs = {'dim': [0, 1]}
        self.outputs = {
            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
        }
        self.gradient = self.calc_gradient()

    def test_check_output(self):
        self.check_output()

    def calc_gradient(self):
        x = self.inputs["X"]
        grad = np.ones(x.shape, dtype=x.dtype)
        return grad,

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', user_defined_grads=self.gradient)


class TestSumOp5D(OpTest):
    def setUp(self):
        self.op_type = "reduce_sum"
......