未验证 提交 51f4291c 编写于 作者: M ming1753 提交者: GitHub

strided_slice grad add fp16 support (#45504)

上级 ad96fe2c
......@@ -28,6 +28,7 @@ PD_REGISTER_KERNEL(strided_slice_raw_grad,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -40,5 +41,6 @@ PD_REGISTER_KERNEL(strided_slice_array_grad,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -990,5 +990,43 @@ class TestStridedSliceTensorArray(unittest.TestCase):
self.create_case(Net27(input_size=112, array_size=13))
@unittest.skipIf(not fluid.core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestStridedSliceFloat16(unittest.TestCase):
def init_test_case(self):
self.op_type = 'strided_slice'
self.input_shape = [3, 3, 3, 6, 7, 8]
self.axes = [0, 1, 2, 3, 4, 5]
self.starts = [1, 0, 0, 0, 1, 2]
self.ends = [2, 2, 3, 1, 2, 8]
self.strides = [1, 1, 1, 1, 1, 2]
self.infer_flags = [1, 1, 1, 1, 1]
def check_main(self, x_np, dtype):
paddle.disable_static()
x_np = x_np.astype(dtype)
x = paddle.to_tensor(x_np)
x.stop_gradient = False
output = strided_slice_native_forward(x, self.axes, self.starts,
self.ends, self.strides)
x_grad = paddle.grad(output, x)
output_np = output[0].numpy().astype('float32')
x_grad_np = x_grad[0].numpy().astype('float32')
paddle.enable_static()
return output_np, x_grad_np
def test_check(self):
self.init_test_case()
x_np = np.random.random(self.input_shape).astype("float16")
output_np_fp16, x_grad_np_fp16 = self.check_main(x_np, 'float16')
output_np_fp32, x_grad_np_fp32 = self.check_main(x_np, 'float32')
np.testing.assert_allclose(output_np_fp16, output_np_fp32)
np.testing.assert_allclose(x_grad_np_fp16, x_grad_np_fp32)
if __name__ == "__main__":
unittest.main()
......@@ -3580,7 +3580,7 @@ def strided_slice(x, axes, starts, ends, strides, name=None):
result = [ [2], ]
Args:
x (Tensor): An N-D ``Tensor``. The data type is ``bool``, ``float32``, ``float64``, ``int32`` or ``int64``.
x (Tensor): An N-D ``Tensor``. The data type is ``bool``, ``float16``, ``float32``, ``float64``, ``int32`` or ``int64``.
axes (list|tuple): The data type is ``int32`` . Axes that `starts` and `ends` apply to.
It's optional. If it is not provides, it will be treated as :math:`[0,1,...,len(starts)-1]`.
starts (list|tuple|Tensor): The data type is ``int32`` . If ``starts`` is a list or tuple, the elements of it should be integers or Tensors with shape [1]. If ``starts`` is an Tensor, it should be an 1-D Tensor. It represents starting indices of corresponding axis in ``axes``.
......@@ -3619,8 +3619,8 @@ def strided_slice(x, axes, starts, ends, strides, name=None):
helper = LayerHelper('strided_slice', **locals())
check_variable_and_dtype(x, 'x',
['bool', 'float32', 'float64', 'int32', 'int64'],
check_variable_and_dtype(
x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
'strided_slice')
check_type(axes, 'axes', (list, tuple), 'strided_slice')
check_type(starts, 'starts', (list, tuple, Variable), 'strided_slice')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册