Unverified commit 8a850af6, authored by Zhang Zheng, committed by GitHub

[AMP OP&Test] Cumprod support fp16 and bf16 (#52919)

Parent: 8601859e
@@ -245,8 +245,8 @@ void InclusiveScan(const T *x,
   if (outer_dim == 1 && inner_dim == 1) {
     if (reverse) {
-      auto x_reverse_iter = MakeThrustReverseIterator(x + mid_dim);
-      auto y_reverse_iter = MakeThrustReverseIterator(y + mid_dim);
+      auto x_reverse_iter = thrust::make_reverse_iterator(x + mid_dim);
+      auto y_reverse_iter = thrust::make_reverse_iterator(y + mid_dim);
       CubInclusiveScan(x_reverse_iter, y_reverse_iter, mid_dim, op, dev_ctx);
     } else {
       CubInclusiveScan(x, y, mid_dim, op, dev_ctx);
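The hunk above swaps a local `MakeThrustReverseIterator` helper for `thrust::make_reverse_iterator` directly; the computation is unchanged. A reversed inclusive scan is just a forward scan over reversed input whose result is read back in reverse, which is what the reverse iterators give without copying. A minimal numpy sketch of that equivalence (the function name here is illustrative, not from the patch):

```python
import numpy as np

def inclusive_scan(x, reverse=False):
    # Forward inclusive product scan; the reversed variant flips the
    # input, scans forward, then flips the result back -- the same
    # trick the reverse iterators implement without materializing a copy.
    if reverse:
        return np.multiply.accumulate(x[::-1])[::-1]
    return np.multiply.accumulate(x)

assert (inclusive_scan(np.array([2.0, 3.0, 4.0]), reverse=True)
        == np.array([24.0, 12.0, 4.0])).all()
```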
@@ -77,7 +77,7 @@ struct CumprodGradFunctorExceptFirstZero {
       first_zero_idx_[outer_idx * inner_dim_ + inner_idx] = 0;
     }
-    x_filled_one_[idx] = should_fill_one ? 1 : x_[idx];
+    x_filled_one_[idx] = should_fill_one ? static_cast<T>(1) : x_[idx];
   }

  private:
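The `1` literal becomes `static_cast<T>(1)` presumably because the new `phi::dtype::float16`/`phi::dtype::bfloat16` element types do not convert implicitly from `int`, which would break the ternary's common-type deduction. For orientation, the functor fills positions from the first zero onward with one so the gradient's product scan stays well-defined; a 1-D numpy sketch of that idea (hypothetical helper, not Paddle code):

```python
import numpy as np

def fill_one_from_first_zero(x):
    # Sketch of the x_filled_one_ idea: elements at or after the first
    # zero along the scan axis are replaced by 1; the first zero itself
    # is handled by a separate code path in the kernel.
    filled = x.copy()
    zeros = np.flatnonzero(x == 0)
    if zeros.size > 0:
        filled[zeros[0]:] = 1
    return filled

print(fill_one_from_first_zero(np.array([2.0, 0.0, 5.0])))  # [2. 1. 1.]
```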
@@ -230,7 +230,7 @@ void CumprodGradKernel(const Context &dev_ctx,
                      outer_dim,
                      mid_dim,
                      inner_dim,
-                     static_cast<T>(0),
+                     static_cast<T>(0.0f),
                      cub::Sum(),
                      /*reverse=*/true,
                      dev_ctx);
@@ -270,7 +270,7 @@ void CumprodGradKernel(const Context &dev_ctx,
                      outer_dim,
                      mid_dim,
                      inner_dim,
-                     static_cast<T>(1),
+                     static_cast<T>(1.0f),
                      funcs::MultiplyFunctor<T>(),
                      /*reverse=*/false,
                      dev_ctx);
@@ -292,7 +292,7 @@ void CumprodGradKernel(const Context &dev_ctx,
                      outer_dim,
                      mid_dim,
                      inner_dim,
-                     static_cast<T>(0),
+                     static_cast<T>(0.0f),
                      cub::Sum(),
                      /*reverse=*/true,
                      dev_ctx);
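All three scan calls now cast their initial value from a float literal. The values themselves are unchanged: 0 is the identity of `cub::Sum` and 1 the identity of `funcs::MultiplyFunctor`, and casting from `0.0f`/`1.0f` presumably converts more cleanly than an `int` to every registered element type, including the half-precision wrappers. A quick numpy check of the identity elements (illustrative only):

```python
import numpy as np

a = np.float16(3.5)
assert np.float16(0.0) + a == a  # 0 is the identity for the sum scan
assert np.float16(1.0) * a == a  # 1 is the identity for the product scan
```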
@@ -319,5 +319,7 @@ PD_REGISTER_KERNEL(cumprod_grad,
                    double,
                    int,
                    int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -60,5 +60,7 @@ PD_REGISTER_KERNEL(cumprod,
                    double,
                    int,
                    int64_t,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
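With the two registrations above, `cumprod` and `cumprod_grad` can dispatch for `float16` and `bfloat16` tensors. A hedged usage sketch (assumes a CUDA build of Paddle running in dynamic graph mode on a device with half-precision support):

```python
import paddle

# Forward and backward in float16, exercising the newly registered kernels.
x = paddle.rand([3, 4]).astype('float16')
x.stop_gradient = False
y = paddle.cumprod(x, dim=1)  # forward: cumprod kernel
y.sum().backward()            # backward: cumprod_grad kernel
print(x.grad.dtype)           # paddle.float16
```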
@@ -16,7 +16,7 @@ import random
 import unittest

 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16

 import paddle
 from paddle.fluid import core
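numpy has no native bfloat16 dtype, so the test stores bf16 data as `np.uint16` bit patterns and converts with `convert_float_to_uint16`. A sketch of the conversion it is understood to perform, taking the upper 16 bits of the float32 representation (the real helper may also round; the function name below is illustrative):

```python
import numpy as np

def float32_to_bf16_bits(x):
    # bfloat16 shares float32's sign/exponent layout, so its bit
    # pattern is (up to rounding) the top 16 bits of the float32 bits.
    return (np.asarray(x, np.float32).view(np.uint32) >> 16).astype(np.uint16)

print(hex(int(float32_to_bf16_bits(1.0))))  # 0x3f80
```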
@@ -64,6 +64,7 @@ class TestCumprod(OpTest):
     def init_dtype(self):
         self.dtype = np.float64
+        self.val_dtype = np.float64

     def setUp(self):
         paddle.enable_static()
@@ -76,7 +77,9 @@ class TestCumprod(OpTest):
         self.attrs = {'dim': None}

     def prepare_inputs_outputs_attrs(self, dim, zero_num):
-        self.x = np.random.random(self.shape).astype(self.dtype) + 0.5
+        self.x = (
+            np.random.uniform(0.0, 0.5, self.shape).astype(self.val_dtype) + 0.5
+        )
         if zero_num > 0:
             zero_num = min(zero_num, self.x.size)
             shape = self.x.shape
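Drawing from `uniform(0.0, 0.5) + 0.5` confines the inputs to [0.5, 1.0), so every partial product stays in (0, 1]. The old `random() + 0.5` range reaches up to 1.5, which a half-precision cumulative product can overflow after a few dozen elements; a quick illustration:

```python
import numpy as np

# Values near 1.5 overflow fp16 (max ~65504) within ~30 products,
# while values below 1.0 shrink monotonically and stay finite.
risky = np.full(30, 1.5, np.float16)
safe = np.full(30, 0.99, np.float16)
print(np.cumprod(risky)[-1])  # inf
print(np.cumprod(safe)[-1])   # small finite value
```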
@@ -86,14 +89,18 @@ class TestCumprod(OpTest):
                 self.x[i] = 0
             self.x = np.reshape(self.x, self.shape)
         self.out = np.cumprod(self.x, axis=dim)
-        self.inputs = {'X': self.x}
-        self.outputs = {'Out': self.out}
+        if self.dtype == np.uint16:
+            self.inputs = {'X': convert_float_to_uint16(self.x)}
+            self.outputs = {'Out': convert_float_to_uint16(self.out)}
+        else:
+            self.inputs = {'X': self.x}
+            self.outputs = {'Out': self.out}
         self.attrs = {'dim': dim}

     def init_grad_input_output(self, dim):
         reshape_x = self.x.reshape(self.x.size)
-        self.grad_out = np.ones(self.x.size, self.dtype)
-        self.grad_x = np.zeros(self.x.size, self.dtype)
+        self.grad_out = np.ones(self.x.size, self.val_dtype)
+        self.grad_x = np.zeros(self.x.size, self.val_dtype)
         out_data = self.out.reshape(self.x.size)
         if self.dtype == np.complex128 or self.dtype == np.complex64:
             reshape_x = np.conj(reshape_x)
@@ -101,8 +108,16 @@ class TestCumprod(OpTest):
         cumprod_grad(
             reshape_x, out_data, self.grad_out, self.grad_x, self.shape, dim
         )
-        self.grad_x = self.grad_x.reshape(self.shape)
-        self.grad_out = self.grad_out.reshape(self.shape)
+        if self.dtype == np.uint16:
+            self.grad_x = convert_float_to_uint16(
+                self.grad_x.reshape(self.shape)
+            )
+            self.grad_out = convert_float_to_uint16(
+                self.grad_out.reshape(self.shape)
+            )
+        else:
+            self.grad_x = self.grad_x.reshape(self.shape)
+            self.grad_out = self.grad_out.reshape(self.shape)

     # test forward.
     def test_check_output(self):
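The `cumprod_grad` helper used above builds the reference gradient that the low-precision runs are checked against (inputs are conjugated first for the complex dtypes, per the hunk before this one). For orientation, in the zero-free 1-D case it reduces to dL/dx_i = Σ_{j≥i} gout_j · out_j / x_i; a compact numpy sketch of that formula (illustrative, not the test helper itself):

```python
import numpy as np

def cumprod_grad_1d(x, out, gout):
    # Zero-free reference gradient of out = cumprod(x):
    # d out_j / d x_i = out_j / x_i for j >= i, else 0.
    gx = np.empty_like(x)
    for i in range(x.size):
        gx[i] = np.sum(gout[i:] * out[i:]) / x[i]
    return gx

x = np.array([2.0, 3.0, 4.0])
out = np.cumprod(x)                          # [ 2.,  6., 24.]
print(cumprod_grad_1d(x, out, np.ones(3)))   # [16., 10.,  6.]
```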
@@ -129,21 +144,62 @@ class TestCumprod(OpTest):

 # test float32 case.
-class TestCumprod_float32(TestCumprod):
+class TestCumprodFP32Op(TestCumprod):
     def init_dtype(self):
         self.dtype = np.float32
+        self.val_dtype = np.float32
+
+
+class TestCumprodFP16Op(TestCumprod):
+    def init_dtype(self):
+        self.dtype = np.float16
+        self.val_dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestCumprodBF16Op(TestCumprod):
+    def init_dtype(self):
+        self.dtype = np.uint16
+        self.val_dtype = np.float32
+
+    # test forward.
+    def test_check_output(self):
+        for dim in range(-len(self.shape), len(self.shape)):
+            for zero_num in self.zero_nums:
+                self.prepare_inputs_outputs_attrs(dim, zero_num)
+                self.check_output_with_place(core.CUDAPlace(0))
+
+    # test backward.
+    def test_check_grad(self):
+        for dim in range(-len(self.shape), len(self.shape)):
+            for zero_num in self.zero_nums:
+                self.prepare_inputs_outputs_attrs(dim, zero_num)
+                self.init_grad_input_output(dim)
+                self.check_grad_with_place(
+                    core.CUDAPlace(0),
+                    ['X'],
+                    'Out',
+                    user_defined_grads=[self.grad_x],
+                    user_defined_grad_outputs=[self.grad_out],
+                )


 # test complex64 case.
-class TestCumprod_complex64(TestCumprod):
+class TestCumprodComplex64Op(TestCumprod):
     def init_dtype(self):
         self.dtype = np.complex64
+        self.val_dtype = np.complex64


 # test complex128 case.
-class TestCumprod_complex128(TestCumprod):
+class TestCumprodComplex128Op(TestCumprod):
     def init_dtype(self):
         self.dtype = np.complex128
+        self.val_dtype = np.complex128


 # test api.