Unverified commit 5cdd9f2c, authored by Wang Xinyu, committed by GitHub

[AMP OP&Test] Strided slice fp16 and bf16 unittest (#52220)

* strided slice fp16 and bf16 unittest

* fix code style

* add self.dtype
Parent 38a477e2
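For context, the new BF16 test below stores its inputs and expected outputs as uint16 arrays via convert_float_to_uint16, since numpy has no native bfloat16 dtype. A minimal numpy sketch of that round trip (illustrative only; the function names here are hypothetical, and the real helper in eager_op_test may round rather than truncate):

import numpy as np

def float_to_bf16_uint16(x):
    # View float32 bits as uint32 and keep the upper 16 bits,
    # which is the bfloat16 bit pattern stored as uint16.
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

def bf16_uint16_to_float(x):
    # Rebuild float32 by shifting the bf16 pattern back into the
    # upper 16 bits; the dropped mantissa bits become zero.
    return (x.astype(np.uint32) << 16).view(np.float32)

a = np.random.rand(4).astype(np.float32)
b = bf16_uint16_to_float(float_to_bf16_uint16(a))
print(np.abs(a - b).max())  # small truncation error; bf16 keeps only 8 mantissa bits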
@@ -28,6 +28,7 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad,
                    int64_t,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -40,5 +41,6 @@ PD_REGISTER_KERNEL(strided_slice_array_grad,
                    int64_t,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -28,6 +28,7 @@ PD_REGISTER_KERNEL(strided_slice_raw,
                    int64_t,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -40,5 +41,6 @@ PD_REGISTER_KERNEL(strided_slice_array,
                    int64_t,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -29,6 +29,7 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -42,5 +43,6 @@ PD_REGISTER_KERNEL(strided_slice_array_grad,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -29,6 +29,7 @@ PD_REGISTER_KERNEL(strided_slice_raw,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -42,5 +43,6 @@ PD_REGISTER_KERNEL(strided_slice_array,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
@@ -52,6 +52,7 @@ PD_REGISTER_KERNEL(strided_slice_grad,
                    int64_t,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -65,6 +66,7 @@ PD_REGISTER_KERNEL(strided_slice_grad,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
 #endif
@@ -43,6 +43,7 @@ PD_REGISTER_KERNEL(strided_slice,
                    int64_t,
                    float,
                    double,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -56,6 +57,7 @@ PD_REGISTER_KERNEL(strided_slice,
                    float,
                    double,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
 #endif
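With bfloat16 now registered for both the CPU and GPU strided_slice kernels above, the op can be exercised end to end on bf16 tensors. A hypothetical usage sketch, not taken from the PR (assumes a CUDA or ROCm Paddle build with bf16 support on the current device; shapes and slice arguments are made up for illustration):

import paddle

# Cast a random tensor to bfloat16 and take a strided slice along the last axis.
x = paddle.randn([3, 4, 10]).astype('bfloat16')
y = paddle.strided_slice(x, axes=[2], starts=[0], ends=[8], strides=[2])
print(y.shape)   # [3, 4, 4]
print(y.dtype)   # paddle.bfloat16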
@@ -15,7 +15,7 @@
 import unittest
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 import paddle
 from paddle import fluid
@@ -600,7 +600,7 @@ class TestStridedSliceAPI(unittest.TestCase):
             feed={
                 "x": input,
                 'starts': np.array([-3, 0, 2]).astype("int32"),
-                'ends': np.array([3, 2147483648, -1]).astype("int64"),
+                'ends': np.array([3, 2147483647, -1]).astype("int32"),
                 'strides': np.array([1, 1, 1]).astype("int32"),
             },
             fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7],
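As a side note on the feed change above: 2147483648 is 2**31, one past the int32 maximum, so it cannot be represented once 'ends' is fed as an int32 tensor. A quick numpy check of the boundary (illustrative only, not part of the PR):

import numpy as np

# The largest value an int32 tensor can hold is 2147483647,
# which is the value the test now feeds for 'ends'.
print(np.iinfo(np.int32).max)                            # 2147483647
print(np.array([3, 2147483647, -1]).astype("int32"))     # [3 2147483647 -1]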
@@ -1011,5 +1011,77 @@ class TestStridedSliceFloat16(unittest.TestCase):
         np.testing.assert_allclose(x_grad_np_fp16, x_grad_np_fp32)
+
+
+class TestStrideSliceFP16Op(OpTest):
+    def setUp(self):
+        self.initTestCase()
+        self.op_type = 'strided_slice'
+        self.dtype = np.float16
+        self.python_api = paddle.strided_slice
+        self.output = strided_slice_native_forward(
+            self.input, self.axes, self.starts, self.ends, self.strides
+        )
+        self.inputs = {'Input': self.input.astype(self.dtype)}
+        self.outputs = {'Out': self.output}
+        self.attrs = {
+            'axes': self.axes,
+            'starts': self.starts,
+            'ends': self.ends,
+            'strides': self.strides,
+            'infer_flags': self.infer_flags,
+        }
+
+    def test_check_output(self):
+        self.check_output(check_eager=True)
+
+    def test_check_grad(self):
+        self.check_grad({'Input'}, 'Out', check_eager=True)
+
+    def initTestCase(self):
+        self.input = np.random.rand(100)
+        self.axes = [0]
+        self.starts = [-4]
+        self.ends = [-3]
+        self.strides = [1]
+        self.infer_flags = [1]
+
+
+class TestStrideSliceBF16Op(OpTest):
+    def setUp(self):
+        self.initTestCase()
+        self.op_type = 'strided_slice'
+        self.dtype = np.uint16
+        self.python_api = paddle.strided_slice
+        self.output = strided_slice_native_forward(
+            self.input, self.axes, self.starts, self.ends, self.strides
+        )
+        self.inputs = {
+            'Input': convert_float_to_uint16(self.input.astype(np.float32))
+        }
+        self.outputs = {'Out': convert_float_to_uint16(self.output)}
+        self.attrs = {
+            'axes': self.axes,
+            'starts': self.starts,
+            'ends': self.ends,
+            'strides': self.strides,
+            'infer_flags': self.infer_flags,
+        }
+
+    def test_check_output(self):
+        self.check_output(check_eager=True)
+
+    def test_check_grad(self):
+        self.check_grad({'Input'}, 'Out', check_eager=True)
+
+    def initTestCase(self):
+        self.input = np.random.rand(100)
+        self.axes = [0]
+        self.starts = [-4]
+        self.ends = [-3]
+        self.strides = [1]
+        self.infer_flags = [1]
+
+
 if __name__ == "__main__":
     unittest.main()