diff --git a/paddle/phi/kernels/funcs/inclusive_scan.h b/paddle/phi/kernels/funcs/inclusive_scan.h
index 550a2f6bea8bdd4d5a55a91d15b943b3a3bbce0b..265febd306f334b246f6902838d4db0160bdcb29 100644
--- a/paddle/phi/kernels/funcs/inclusive_scan.h
+++ b/paddle/phi/kernels/funcs/inclusive_scan.h
@@ -245,8 +245,8 @@ void InclusiveScan(const T *x,
 
   if (outer_dim == 1 && inner_dim == 1) {
     if (reverse) {
-      auto x_reverse_iter = MakeThrustReverseIterator(x + mid_dim);
-      auto y_reverse_iter = MakeThrustReverseIterator(y + mid_dim);
+      auto x_reverse_iter = thrust::make_reverse_iterator(x + mid_dim);
+      auto y_reverse_iter = thrust::make_reverse_iterator(y + mid_dim);
       CubInclusiveScan(x_reverse_iter, y_reverse_iter, mid_dim, op, dev_ctx);
     } else {
       CubInclusiveScan(x, y, mid_dim, op, dev_ctx);
diff --git a/paddle/phi/kernels/gpu/cumprod_grad_kernel.cu b/paddle/phi/kernels/gpu/cumprod_grad_kernel.cu
index f869ea0669f2edc1f454ec2ee99891aebe89266a..fdd9b4ba4991461c753bb07dd267c91bfcc013ed 100644
--- a/paddle/phi/kernels/gpu/cumprod_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/cumprod_grad_kernel.cu
@@ -77,7 +77,7 @@ struct CumprodGradFunctorExceptFirstZero {
       first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = 0;
     }
 
-    x_filled_one_[idx] = should_fill_one ? 1 : x_[idx];
+    x_filled_one_[idx] = should_fill_one ? static_cast<T>(1) : x_[idx];
   }
 
  private:
@@ -230,7 +230,7 @@ void CumprodGradKernel(const Context &dev_ctx,
       outer_dim,
       mid_dim,
       inner_dim,
-      static_cast<T>(0),
+      static_cast<T>(0.0f),
       cub::Sum(),
       /*reverse=*/true,
       dev_ctx);
@@ -270,7 +270,7 @@ void CumprodGradKernel(const Context &dev_ctx,
      outer_dim,
      mid_dim,
      inner_dim,
-     static_cast<T>(1),
+     static_cast<T>(1.0f),
      funcs::MultiplyFunctor<T>(),
      /*reverse=*/false,
      dev_ctx);
@@ -292,7 +292,7 @@ void CumprodGradKernel(const Context &dev_ctx,
      outer_dim,
      mid_dim,
      inner_dim,
-     static_cast<T>(0),
+     static_cast<T>(0.0f),
      cub::Sum(),
      /*reverse=*/true,
      dev_ctx);
@@ -319,5 +319,7 @@ PD_REGISTER_KERNEL(cumprod_grad,
                    double,
                    int,
                    int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/cumprod_kernel.cu b/paddle/phi/kernels/gpu/cumprod_kernel.cu
index bff5a09cd4e239c604b531f1e443dee54203a55d..d637477e1d2e799dadcd2f8ef7e903e92cbaccda 100644
--- a/paddle/phi/kernels/gpu/cumprod_kernel.cu
+++ b/paddle/phi/kernels/gpu/cumprod_kernel.cu
@@ -60,5 +60,7 @@ PD_REGISTER_KERNEL(cumprod,
                    double,
                    int,
                    int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
diff --git a/python/paddle/fluid/tests/unittests/test_cumprod_op.py b/python/paddle/fluid/tests/unittests/test_cumprod_op.py
index 2fc9a8835282f4ccf4b465814a7841deaa668d33..65b3c8da65870c8c5db1a1d9ba871ef05151445e 100644
--- a/python/paddle/fluid/tests/unittests/test_cumprod_op.py
+++ b/python/paddle/fluid/tests/unittests/test_cumprod_op.py
@@ -16,7 +16,7 @@ import random
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle.fluid import core
@@ -64,6 +64,7 @@ class TestCumprod(OpTest):
 
     def init_dtype(self):
         self.dtype = np.float64
+        self.val_dtype = np.float64
 
     def setUp(self):
         paddle.enable_static()
@@ -76,7 +77,9 @@ class TestCumprod(OpTest):
         self.attrs = {'dim': None}
 
     def prepare_inputs_outputs_attrs(self, dim, zero_num):
-        self.x = np.random.random(self.shape).astype(self.dtype) + 0.5
+        self.x = (
+            np.random.uniform(0.0, 0.5, self.shape).astype(self.val_dtype) + 0.5
+        )
         if zero_num > 0:
             zero_num = min(zero_num, self.x.size)
             shape = self.x.shape
@@ -86,14 +89,18 @@ class TestCumprod(OpTest):
                 self.x[i] = 0
             self.x = np.reshape(self.x, self.shape)
         self.out = np.cumprod(self.x, axis=dim)
-        self.inputs = {'X': self.x}
-        self.outputs = {'Out': self.out}
+        if self.dtype == np.uint16:
+            self.inputs = {'X': convert_float_to_uint16(self.x)}
+            self.outputs = {'Out': convert_float_to_uint16(self.out)}
+        else:
+            self.inputs = {'X': self.x}
+            self.outputs = {'Out': self.out}
         self.attrs = {'dim': dim}
 
     def init_grad_input_output(self, dim):
         reshape_x = self.x.reshape(self.x.size)
-        self.grad_out = np.ones(self.x.size, self.dtype)
-        self.grad_x = np.zeros(self.x.size, self.dtype)
+        self.grad_out = np.ones(self.x.size, self.val_dtype)
+        self.grad_x = np.zeros(self.x.size, self.val_dtype)
         out_data = self.out.reshape(self.x.size)
         if self.dtype == np.complex128 or self.dtype == np.complex64:
             reshape_x = np.conj(reshape_x)
@@ -101,8 +108,16 @@ class TestCumprod(OpTest):
         cumprod_grad(
             reshape_x, out_data, self.grad_out, self.grad_x, self.shape, dim
         )
-        self.grad_x = self.grad_x.reshape(self.shape)
-        self.grad_out = self.grad_out.reshape(self.shape)
+        if self.dtype == np.uint16:
+            self.grad_x = convert_float_to_uint16(
+                self.grad_x.reshape(self.shape)
+            )
+            self.grad_out = convert_float_to_uint16(
+                self.grad_out.reshape(self.shape)
+            )
+        else:
+            self.grad_x = self.grad_x.reshape(self.shape)
+            self.grad_out = self.grad_out.reshape(self.shape)
 
     # test forward.
     def test_check_output(self):
@@ -129,21 +144,62 @@ class TestCumprod(OpTest):
 
 
 # test float32 case.
-class TestCumprod_float32(TestCumprod):
+class TestCumprodFP32Op(TestCumprod):
     def init_dtype(self):
         self.dtype = np.float32
+        self.val_dtype = np.float32
+
+
+class TestCumprodFP16Op(TestCumprod):
+    def init_dtype(self):
+        self.dtype = np.float16
+        self.val_dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestCumprodBF16Op(TestCumprod):
+    def init_dtype(self):
+        self.dtype = np.uint16
+        self.val_dtype = np.float32
+
+    # test forward.
+    def test_check_output(self):
+        for dim in range(-len(self.shape), len(self.shape)):
+            for zero_num in self.zero_nums:
+                self.prepare_inputs_outputs_attrs(dim, zero_num)
+                self.check_output_with_place(core.CUDAPlace(0))
+
+    # test backward.
+    def test_check_grad(self):
+        for dim in range(-len(self.shape), len(self.shape)):
+            for zero_num in self.zero_nums:
+                self.prepare_inputs_outputs_attrs(dim, zero_num)
+                self.init_grad_input_output(dim)
+                self.check_grad_with_place(
+                    core.CUDAPlace(0),
+                    ['X'],
+                    'Out',
+                    user_defined_grads=[self.grad_x],
+                    user_defined_grad_outputs=[self.grad_out],
+                )
 
 
 # test complex64 case.
-class TestCumprod_complex64(TestCumprod):
+class TestCumprodComplex64Op(TestCumprod):
     def init_dtype(self):
         self.dtype = np.complex64
+        self.val_dtype = np.complex64
 
 
 # test complex128 case.
-class TestCumprod_complex128(TestCumprod):
+class TestCumprodComplex128Op(TestCumprod):
     def init_dtype(self):
         self.dtype = np.complex128
+        self.val_dtype = np.complex128
 
 
 # test api.
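
For reference, a minimal usage sketch of the dtypes this patch registers, assuming a CUDA build of Paddle with the change applied; the device string, tensor shape, and reduction dim below are illustrative, not part of the patch:

```python
import paddle

# Assumes a CUDA build of Paddle where cumprod/cumprod_grad are registered
# for float16 and bfloat16 (as in the kernel registrations above).
paddle.set_device("gpu")

# Keep values away from zero so the running product stays well conditioned.
x = paddle.uniform([2, 3, 4], min=0.5, max=1.0)

for dtype in ("float16", "bfloat16"):
    xt = x.astype(dtype)
    xt.stop_gradient = False
    y = paddle.cumprod(xt, dim=1)  # forward hits the GPU cumprod kernel
    y.sum().backward()             # backward hits the GPU cumprod_grad kernel
    print(dtype, y.dtype, xt.grad.dtype)
```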