diff --git a/paddle/phi/kernels/gpu/cross_grad_kernel.cu b/paddle/phi/kernels/gpu/cross_grad_kernel.cu index 4a27a8e2b05f97c7c5031e31ea7963d3304efc3a..b3316ea875b9060a7d0a73c86d7bd7fd8517760f 100644 --- a/paddle/phi/kernels/gpu/cross_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_grad_kernel.cu @@ -38,14 +38,32 @@ __global__ void CrossGrad(const T* x, auto pos1 = offset + 1 * stride; auto pos2 = offset + 2 * stride; - out_dx[pos0] = out[pos2] * y[pos1] - out[pos1] * y[pos2]; - out_dy[pos0] = out[pos1] * x[pos2] - out[pos2] * x[pos1]; - - out_dx[pos1] = out[pos0] * y[pos2] - out[pos2] * y[pos0]; - out_dy[pos1] = out[pos2] * x[pos0] - out[pos0] * x[pos2]; - - out_dx[pos2] = out[pos1] * y[pos0] - out[pos0] * y[pos1]; - out_dy[pos2] = out[pos0] * x[pos1] - out[pos1] * x[pos0]; + using MPType = typename phi::dtype::MPTypeTrait::Type; + + MPType x_pos0_mp = static_cast(x[pos0]); + MPType x_pos1_mp = static_cast(x[pos1]); + MPType x_pos2_mp = static_cast(x[pos2]); + MPType y_pos0_mp = static_cast(y[pos0]); + MPType y_pos1_mp = static_cast(y[pos1]); + MPType y_pos2_mp = static_cast(y[pos2]); + MPType out_pos0_mp = static_cast(out[pos0]); + MPType out_pos1_mp = static_cast(out[pos1]); + MPType out_pos2_mp = static_cast(out[pos2]); + + out_dx[pos0] = + static_cast(out_pos2_mp * y_pos1_mp - out_pos1_mp * y_pos2_mp); + out_dy[pos0] = + static_cast(out_pos1_mp * x_pos2_mp - out_pos2_mp * x_pos1_mp); + + out_dx[pos1] = + static_cast(out_pos0_mp * y_pos2_mp - out_pos2_mp * y_pos0_mp); + out_dy[pos1] = + static_cast(out_pos2_mp * x_pos0_mp - out_pos0_mp * x_pos2_mp); + + out_dx[pos2] = + static_cast(out_pos1_mp * y_pos0_mp - out_pos0_mp * y_pos1_mp); + out_dy[pos2] = + static_cast(out_pos0_mp * x_pos1_mp - out_pos1_mp * x_pos0_mp); } } @@ -172,6 +190,7 @@ PD_REGISTER_KERNEL(cross_grad, GPU, ALL_LAYOUT, phi::CrossGradKernel, + phi::dtype::float16, float, double, int, diff --git a/paddle/phi/kernels/gpu/cross_kernel.cu b/paddle/phi/kernels/gpu/cross_kernel.cu index 875c043188d4ce1721bcfab29f28c98e25b05c27..60623cb8e3d747063cea5c20e36660ab8849853b 100644 --- a/paddle/phi/kernels/gpu/cross_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_kernel.cu @@ -16,6 +16,7 @@ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/index_calculator.h" @@ -36,9 +37,18 @@ __global__ void Cross(const T* x, auto pos1 = offset + 1 * stride; auto pos2 = offset + 2 * stride; - out[pos0] = x[pos1] * y[pos2] - x[pos2] * y[pos1]; - out[pos1] = x[pos2] * y[pos0] - x[pos0] * y[pos2]; - out[pos2] = x[pos0] * y[pos1] - x[pos1] * y[pos0]; + using MPType = typename phi::dtype::MPTypeTrait::Type; + + MPType x_pos0_mp = static_cast(x[pos0]); + MPType x_pos1_mp = static_cast(x[pos1]); + MPType x_pos2_mp = static_cast(x[pos2]); + MPType y_pos0_mp = static_cast(y[pos0]); + MPType y_pos1_mp = static_cast(y[pos1]); + MPType y_pos2_mp = static_cast(y[pos2]); + + out[pos0] = static_cast(x_pos1_mp * y_pos2_mp - x_pos2_mp * y_pos1_mp); + out[pos1] = static_cast(x_pos2_mp * y_pos0_mp - x_pos0_mp * y_pos2_mp); + out[pos2] = static_cast(x_pos0_mp * y_pos1_mp - x_pos1_mp * y_pos0_mp); } } @@ -153,5 +163,12 @@ void CrossKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL( - cross, GPU, ALL_LAYOUT, phi::CrossKernel, float, double, int, int64_t) {} +PD_REGISTER_KERNEL(cross, + GPU, + ALL_LAYOUT, + phi::CrossKernel, + phi::dtype::float16, + float, + double, + int, + int64_t) {} diff --git a/python/paddle/fluid/tests/unittests/test_cross_op.py b/python/paddle/fluid/tests/unittests/test_cross_op.py index 29bdf93cf1c7ba6ce671c910796e7c0965e36e6c..fbfc992d6a06a9a0b1df54084be4a4f6d246478e 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_op.py +++ b/python/paddle/fluid/tests/unittests/test_cross_op.py @@ -65,6 +65,18 @@ class TestCrossOpCase1(TestCrossOp): self.outputs = {'Out': np.array(z_list).reshape(self.shape)} +class TestCrossFP16Op(TestCrossOp): + def initTestCase(self): + self.shape = (2048, 3) + self.dtype = np.float16 + + def init_output(self): + z_list = [] + for i in range(2048): + z_list.append(np.cross(self.inputs['X'][i], self.inputs['Y'][i])) + self.outputs = {'Out': np.array(z_list).reshape(self.shape)} + + class TestCrossAPI(unittest.TestCase): def input_data(self): self.data_x = np.array( diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index c9b47df4d1c5e7e1ed090c942d2e8ec69d4940cd..d308735a949f1452fbd16b244a5213bacd075f44 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -1334,8 +1334,8 @@ def cross(x, y, axis=9, name=None): If `axis` is not given, it defaults to the first axis found with the length 3. Args: - x (Tensor): The first input tensor. - y (Tensor): The second input tensor. + x (Tensor): The first input tensor, the data type is float16, float32, float64, int32, int64. + y (Tensor): The second input tensor, the data type is float16, float32, float64, int32, int64. axis (int, optional): The axis along which to compute the cross product. It defaults to be 9 which indicates using the first axis found with the length 3. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. @@ -1368,6 +1368,18 @@ def cross(x, y, axis=9, name=None): axis = K_DEFAULT_DIM if axis is None else axis return _C_ops.cross(x, y, axis) else: + check_variable_and_dtype( + x, + 'x', + ['float16', 'float32', 'float64', "int32", "int64"], + 'cross', + ) + check_variable_and_dtype( + y, + 'y', + ['float16', 'float32', 'float64', "int32", "int64"], + 'cross', + ) helper = LayerHelper("cross", **locals()) out = helper.create_variable_for_type_inference(x.dtype) attrs = dict()