Unverified · Commit c91b1b91 · Authored by: thunder95 · Committed by: GitHub

[PaddlePaddle Hackathon 3 No.45 & 46]: Support the float16 data type for Paddle cumsum and logcumsumexp (#45952)

Parent: a3436672
@@ -27,11 +27,25 @@ namespace cub = hipcub;
 #endif

 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/hostdevice.h"
 #include "paddle/phi/core/kernel_registry.h"

 namespace phi {

+template <typename T>
+class CumTypeTrait {
+ public:
+  using Type = T;
+};
+
+template <>
+class CumTypeTrait<phi::dtype::float16> {
+ public:
+  using Type = __half;
+};
+
 template <typename T, int BLOCK_SIZE>
 __device__ void BlockReverse(
     const T* idata, T* odata, int src_base, int dst_base, int valid_item) {
@@ -39,7 +53,7 @@ __device__ void BlockReverse(
   int tx = threadIdx.x;
   int offset = tx;
-  T src_data = 0;
+  T src_data = static_cast<T>(0);
   int src_offset = BLOCK_SIZE - offset - 1;
   if (src_offset < valid_item) {
     src_data = idata[src_base + src_offset];
@@ -160,14 +174,18 @@ __global__ void BlockScanKernel(T* d_out,
                                 int scan_size,
                                 bool exclusive,
                                 Op op) {
+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   // Specialize BlockLoad, BlockStore, and BlockRadixSort collective types
   typedef cub::
-      BlockLoad<T, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
+      BlockLoad<MT, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_TRANSPOSE>
          BlockLoadT;
-  typedef cub::
-      BlockStore<T, BLOCK_THREADS, ITEMS_PER_THREAD, cub::BLOCK_STORE_TRANSPOSE>
-          BlockStoreT;
-  typedef cub::BlockScan<T, BLOCK_THREADS> BlockScanT;
+  typedef cub::BlockStore<MT,
+                          BLOCK_THREADS,
+                          ITEMS_PER_THREAD,
+                          cub::BLOCK_STORE_TRANSPOSE>
+      BlockStoreT;
+  typedef cub::BlockScan<MT, BLOCK_THREADS> BlockScanT;
   // Allocate type-safe, repurposable shared memory for collectives
   __shared__ union {
     typename BlockLoadT::TempStorage load;
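Note: MPTypeTrait picks a wider accumulation type for the block scan (for phi::dtype::float16 this is presumably float), so partial sums are not kept in half precision. The following numpy sketch is my own illustration of why that matters; the array size and values are not taken from the patch.

    import numpy as np

    np.random.seed(0)
    x = np.random.rand(10000).astype(np.float16)

    half_acc = np.cumsum(x)                    # running sum kept in float16
    wide_acc = np.cumsum(x, dtype=np.float32)  # accumulate in float32, as MT does for fp16

    # the half-precision running sum drifts noticeably once the total grows
    print(float(half_acc[-1]), float(wide_acc[-1]))
    print(abs(float(half_acc[-1]) - float(wide_acc[-1])) / float(wide_acc[-1]))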
@@ -176,8 +194,7 @@ __global__ void BlockScanKernel(T* d_out,
   } temp_storage;

   int bx = blockIdx.x;
-
-  BlockPrefixCallbackOp<T, Op> prefix_op(Identity<T, Op>::value, op);
+  BlockPrefixCallbackOp<MT, Op> prefix_op(Identity<MT, Op>::value, op);

   // Obtain this block's segment of consecutive keys (blocked across threads)
   int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
@@ -192,7 +209,7 @@ __global__ void BlockScanKernel(T* d_out,
     int offset = block_offset + bx * scan_size;
-    T thread_keys[ITEMS_PER_THREAD];
+    MT thread_keys[ITEMS_PER_THREAD];
     BlockLoadT(temp_storage.load)
         .Load(d_in + offset, thread_keys, valid_item, 0);
@@ -241,17 +258,22 @@ void ScanKernel(const Context& dev_ctx,
   // Use thrust for parallel acceleration when the input size is equal to the
   // length of the ‘axis’ dimension.
-  if (std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
+  if (!std::is_same<T, phi::dtype::float16>::value &&
+      std::is_same<Op, cub::Sum>::value && size == out_dims[axis]) {
 #ifdef __HIPCC__
     const auto& policy = thrust::hip::par.on(dev_ctx.stream());
 #else
     const auto& policy = thrust::cuda::par.on(dev_ctx.stream());
 #endif
+    using CumType = typename CumTypeTrait<T>::Type;
+    CumType* out_data_ptr = reinterpret_cast<CumType*>(out_data);
+    const CumType* in_data_ptr = reinterpret_cast<const CumType*>(in_data);
     if (reverse) {
-      thrust::reverse_iterator<thrust::device_ptr<const T>> reversed_in(
-          thrust::device_pointer_cast(in_data) + size);
-      thrust::reverse_iterator<thrust::device_ptr<T>> reversed_out(
-          thrust::device_pointer_cast(out_data) + size);
+      thrust::reverse_iterator<thrust::device_ptr<const CumType>> reversed_in(
+          thrust::device_pointer_cast(in_data_ptr) + size);
+      thrust::reverse_iterator<thrust::device_ptr<CumType>> reversed_out(
+          thrust::device_pointer_cast(out_data_ptr) + size);
       if (exclusive) {
         thrust::exclusive_scan(
             policy, reversed_in, reversed_in + size, reversed_out);
@@ -261,11 +283,14 @@ void ScanKernel(const Context& dev_ctx,
       }
     } else {
       if (exclusive) {
-        thrust::exclusive_scan(policy, in_data, in_data + size, out_data);
+        thrust::exclusive_scan(
+            policy, in_data_ptr, in_data_ptr + size, out_data_ptr);
       } else {
-        thrust::inclusive_scan(policy, in_data, in_data + size, out_data);
+        thrust::inclusive_scan(
+            policy, in_data_ptr, in_data_ptr + size, out_data_ptr);
       }
     }
     return;
   }
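Note: the thrust fast path above is taken only for cub::Sum when the scan covers the whole tensor and, after this patch, only for non-float16 inputs. For reference, the scan variants it dispatches (inclusive vs. exclusive, plus the reverse case built from reverse iterators) behave like the following numpy sketch, which is my own illustration rather than Paddle code.

    import numpy as np

    x = np.array([1., 2., 3., 4.], dtype=np.float32)

    inclusive = np.cumsum(x)                            # [ 1.  3.  6. 10.]
    exclusive = np.concatenate(([0.], inclusive[:-1]))  # [ 0.  1.  3.  6.]

    # reverse=True: scan from the last element, i.e. reverse input and output
    reverse_inclusive = np.cumsum(x[::-1])[::-1]        # [10.  9.  7.  4.]
    print(inclusive, exclusive, reverse_inclusive)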
@@ -305,7 +330,6 @@ void ScanKernel(const Context& dev_ctx,
   int outer_size = height / scan_size;
   int inner_size = width;
-  // Consider the size of shared memory, here block size is 128
   dim3 scan_grid(outer_size, inner_size);
   dim3 reverse_grid = scan_grid;
   if (reverse) {
@@ -380,6 +404,7 @@ void LogcumsumexpKernel(const Context& dev_ctx,
 }  // namespace phi

+#ifdef PADDLE_WITH_HIP
 PD_REGISTER_KERNEL(cumsum,
                    GPU,
                    ALL_LAYOUT,
@@ -392,3 +417,23 @@ PD_REGISTER_KERNEL(cumsum,
 PD_REGISTER_KERNEL(
     logcumsumexp, GPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {}
+#else
+PD_REGISTER_KERNEL(cumsum,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::CumsumKernel,
+                   float,
+                   double,
+                   int16_t,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
+PD_REGISTER_KERNEL(logcumsumexp,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LogcumsumexpKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
+#endif
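Note: per the guard above, the float16 registrations are added only on the CUDA (non-HIP) build; the ROCm branch keeps the previous dtype list. A caller that wants the half-precision path could guard on the build roughly as in this hypothetical sketch (the helper name is mine, not part of the patch).

    import paddle

    def cumsum_prefer_fp16(x, axis=-1):
        # use the float16 kernel only where this patch registers it
        if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm():
            return paddle.cumsum(x.astype('float16'), axis=axis)
        return paddle.cumsum(x.astype('float32'), axis=axis)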
@@ -20,9 +20,19 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/logcumsumexp_grad_impl.h"

+#ifdef PADDLE_WITH_HIP
 PD_REGISTER_KERNEL(logcumsumexp_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::LogcumsumexpGradKernel,
                    float,
                    double) {}
+#else
+PD_REGISTER_KERNEL(logcumsumexp_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LogcumsumexpGradKernel,
+                   phi::dtype::float16,
+                   float,
+                   double) {}
+#endif
@@ -12,9 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

+#pragma once
 #include <limits>

 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/cum_kernel.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
@@ -55,32 +57,38 @@ void LogcumsumexpGradKernel(const Context& dev_ctx,
   auto eigen_d_out = EigenVector<T>::Flatten(d_out);
   auto& place = *dev_ctx.eigen_device();

+  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
   DenseTensor output_pos;
   output_pos.Resize(d_out.dims());
-  dev_ctx.template Alloc<T>(&output_pos);
-  auto eigen_output_pos = EigenVector<T>::Flatten(output_pos);
+  dev_ctx.template Alloc<MT>(&output_pos);
+  auto eigen_output_pos = EigenVector<MT>::Flatten(output_pos);
   DenseTensor output_neg;
   output_neg.Resize(d_out.dims());
-  dev_ctx.template Alloc<T>(&output_neg);
-  auto eigen_output_neg = EigenVector<T>::Flatten(output_neg);
+  dev_ctx.template Alloc<MT>(&output_neg);
+  auto eigen_output_neg = EigenVector<MT>::Flatten(output_neg);
   DenseTensor tmp;
   tmp.Resize(d_out.dims());
-  dev_ctx.template Alloc<T>(&tmp);
-  auto eigen_tmp = EigenVector<T>::Flatten(tmp);
+  dev_ctx.template Alloc<MT>(&tmp);
+  auto eigen_tmp = EigenVector<MT>::Flatten(tmp);

   eigen_tmp.device(place) =
-      eigen_d_out.unaryExpr(LogGradPositiveFunctor<T>()) - eigen_out;
-  LogcumsumexpKernel<T, Context>(
+      eigen_d_out.template cast<MT>().unaryExpr(LogGradPositiveFunctor<MT>()) -
+      eigen_out.template cast<MT>();
+  LogcumsumexpKernel<MT, Context>(
       dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_pos);
-  eigen_output_pos.device(place) = (eigen_output_pos + eigen_x).exp();
+  auto out_pos = eigen_output_pos + eigen_x.template cast<MT>();
+  eigen_output_pos.device(place) = out_pos.exp();

   eigen_tmp.device(place) =
-      eigen_d_out.unaryExpr(LogGradNegativeFunctor<T>()) - eigen_out;
-  LogcumsumexpKernel<T, Context>(
+      eigen_d_out.template cast<MT>().unaryExpr(LogGradNegativeFunctor<MT>()) -
+      eigen_out.template cast<MT>();
+  LogcumsumexpKernel<MT, Context>(
       dev_ctx, tmp, axis, flatten, exclusive, reverse, &output_neg);
-  eigen_output_neg.device(place) = (eigen_output_neg + eigen_x).exp();
+  auto out_neg = eigen_output_neg + eigen_x.template cast<MT>();
+  eigen_output_neg.device(place) = out_neg.exp();

   auto eigen_d_x = EigenVector<T>::Flatten(*d_x);
-  eigen_d_x.device(place) = eigen_output_pos - eigen_output_neg;
+  eigen_d_x.device(place) =
+      (eigen_output_pos - eigen_output_neg).template cast<T>();
 }

 }  // namespace phi
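Note: the positive/negative split above follows the usual logcumsumexp gradient written in log space. For the inclusive, non-reversed case (the reverse and exclusive flags only change the scan direction and offset), with y = logcumsumexp(x) along the scanned axis and g the upstream gradient, the backward pass computes

    \[
      y_j = \log \sum_{i \le j} e^{x_i},
      \qquad
      \frac{\partial L}{\partial x_i}
        = \sum_{j \ge i} g_j \, e^{x_i - y_j}
        = e^{\,x_i + \tilde{y}^{+}_i} - e^{\,x_i + \tilde{y}^{-}_i},
    \]

where \(g^{+} = \max(g, 0)\), \(g^{-} = \max(-g, 0)\), and \(\tilde{y}^{\pm}\) is the logcumsumexp of \(\log g^{\pm} - y\) scanned from the last element toward the first. The decomposition can be checked numerically against a brute-force Jacobian; the sketch below is my own reference code, not Paddle's.

    import numpy as np

    def logcumsumexp(a):
        return np.logaddexp.accumulate(a)

    def rev_logcumsumexp(a):
        return np.logaddexp.accumulate(a[::-1])[::-1]

    np.random.seed(0)
    x = np.random.randn(8)
    g = np.random.randn(8)          # upstream gradient
    y = logcumsumexp(x)

    # brute force: dL/dx_i = sum_{j >= i} g_j * exp(x_i - y_j)
    jac = np.zeros((8, 8))
    for j in range(8):
        jac[j, : j + 1] = np.exp(x[: j + 1] - y[j])
    dx_ref = jac.T @ g

    # positive/negative decomposition, mirroring the kernel above
    with np.errstate(divide='ignore'):
        log_pos = np.log(np.clip(g, 0.0, None))    # -inf where g <= 0
        log_neg = np.log(np.clip(-g, 0.0, None))   # -inf where g >= 0
    dx = (np.exp(rev_logcumsumexp(log_pos - y) + x) -
          np.exp(rev_logcumsumexp(log_neg - y) + x))

    print(np.allclose(dx, dx_ref))  # True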
@@ -199,6 +199,32 @@ class TestSumOp7(OpTest):
         self.check_grad(['X'], 'Out')


+class TestCumsumFP16(unittest.TestCase):
+
+    def check_main(self, x_np, dtype):
+        paddle.disable_static()
+        x = paddle.to_tensor(x_np.astype(dtype))
+        x.stop_gradient = False
+        y = paddle.cumsum(x, dtype=dtype)
+        x_g = paddle.grad(y, [x])
+        y_np = y.numpy().astype('float32')
+        x_g_np = x_g[0].numpy().astype('float32')
+        paddle.enable_static()
+        return y_np, x_g_np
+
+    def test_main(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+
+        np.random.seed(20)
+        x_np = np.random.random([10, 12])
+        y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
+        y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')
+        np.testing.assert_allclose(y_np_1, y_np_2, rtol=1e-03)
+        np.testing.assert_allclose(x_g_np_1, x_g_np_2, rtol=1e-03)
+
+
 class TestSumOpExclusive1(OpTest):

     def setUp(self):
@@ -289,6 +315,24 @@ class TestSumOpExclusive5(OpTest):
         self.check_output()


+class TestSumOpExclusiveFP16(OpTest):
+
+    def setUp(self):
+        self.op_type = "cumsum"
+        self.attrs = {'axis': 2, "exclusive": True, "dtype": "float16"}
+        a = np.random.random((4, 5, 3096)).astype("float64")
+        self.inputs = {'X': a}
+        self.outputs = {
+            'Out':
+            np.concatenate((np.zeros(
+                (4, 5, 1), dtype=np.float64), a[:, :, :-1].cumsum(axis=2)),
+                           axis=2)
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestSumOpReverseExclusive(OpTest):

     def setUp(self):
...
@@ -210,6 +210,8 @@ class BaseTestCases:
             input, attrs = self.input_and_attrs()
             self.inputs = {'X': input}
             self.attrs = attrs
+            if "dtype" in attrs:
+                del attrs["dtype"]
             self.outputs = {'Out': np_logcumsumexp(input, **attrs)}

         def test_check_output(self):
@@ -264,5 +266,36 @@ class TestLogcumsumexpOp4(BaseTestCases.BaseOpTest):
         }


+class TestLogcumsumexpFP16(unittest.TestCase):
+
+    def check_main(self, x_np, dtype, axis=None):
+        paddle.disable_static()
+        x = paddle.to_tensor(x_np.astype(dtype))
+        x.stop_gradient = False
+        y = paddle.logcumsumexp(x, dtype=dtype, axis=axis)
+        x_g = paddle.grad(y, [x])
+        y_np = y.numpy().astype('float32')
+        x_g_np = x_g[0].numpy().astype('float32')
+        paddle.enable_static()
+        return y_np, x_g_np
+
+    def test_main(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+
+        np.random.seed(20)
+        x_np = np.random.random([10, 12])
+
+        y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
+        y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')
+        np.testing.assert_allclose(y_np_1, y_np_2, rtol=1e-03)
+        np.testing.assert_allclose(x_g_np_1, x_g_np_2, rtol=1e-03)
+
+        y_np_1, x_g_np_1 = self.check_main(x_np, 'float16', axis=1)
+        y_np_2, x_g_np_2 = self.check_main(x_np, 'float32', axis=1)
+        np.testing.assert_allclose(y_np_1, y_np_2, rtol=1e-03)
+        np.testing.assert_allclose(x_g_np_1, x_g_np_2, rtol=2e-03)
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -3173,7 +3173,7 @@ def cumsum(x, axis=None, dtype=None, name=None):
     Args:
         x (Tensor): The input tensor needed to be cumsumed.
         axis (int, optional): The dimension to accumulate along. -1 means the last dimension. The default (None) is to compute the cumsum over the flattened array.
-        dtype (str, optional): The data type of the output tensor, can be float32, float64, int32, int64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None.
+        dtype (str, optional): The data type of the output tensor, can be float16, float32, float64, int32, int64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

     Returns:
@@ -3246,7 +3246,7 @@ def logcumsumexp(x, axis=None, dtype=None, name=None):
     Args:
         x (Tensor): The input tensor.
         axis (int, optional): The dimension to do the operation along. -1 means the last dimension. The default (None) is to compute the cumsum over the flattened array.
-        dtype (str, optional): The data type of the output tensor, can be float32, float64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None.
+        dtype (str, optional): The data type of the output tensor, can be float16, float32, float64. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. The default value is None.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

     Returns:
@@ -3295,7 +3295,8 @@ def logcumsumexp(x, axis=None, dtype=None, name=None):
            return _legacy_C_ops.logcumsumexp(x, 'axis', axis, 'flatten',
                                              flatten)

-    check_variable_and_dtype(x, 'x', ['float32', 'float64'], "logcumsumexp")
+    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
+                             "logcumsumexp")

    helper = LayerHelper('logcumsumexp', **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
...
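Note: with these changes the dtype argument of both APIs accepts 'float16'. A minimal usage sketch, assuming a CUDA build of Paddle that includes the kernels registered above:

    import numpy as np
    import paddle

    x = paddle.to_tensor(np.random.rand(10, 12), dtype='float16')
    y = paddle.cumsum(x, axis=1, dtype='float16')
    z = paddle.logcumsumexp(x, axis=1, dtype='float16')
    print(y.dtype, z.dtype)  # paddle.float16 paddle.float16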