未验证 提交 871d2d36 编写于 作者: D denglianbin 提交者: GitHub

【Hackathon No.47】为 Paddle cross 算子实现 float16 数据类型支持 (#50924)

* finish task

* add static_check and fix unittest.

* add int32/64

* Update test_cross_op.py

---------
Co-authored-by: Zhang Ting <Douyaer2020@qq.com>
上级 628ddcf3
......@@ -38,14 +38,32 @@ __global__ void CrossGrad(const T* x,
auto pos1 = offset + 1 * stride;
auto pos2 = offset + 2 * stride;
out_dx[pos0] = out[pos2] * y[pos1] - out[pos1] * y[pos2];
out_dy[pos0] = out[pos1] * x[pos2] - out[pos2] * x[pos1];
out_dx[pos1] = out[pos0] * y[pos2] - out[pos2] * y[pos0];
out_dy[pos1] = out[pos2] * x[pos0] - out[pos0] * x[pos2];
out_dx[pos2] = out[pos1] * y[pos0] - out[pos0] * y[pos1];
out_dy[pos2] = out[pos0] * x[pos1] - out[pos1] * x[pos0];
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType x_pos0_mp = static_cast<MPType>(x[pos0]);
MPType x_pos1_mp = static_cast<MPType>(x[pos1]);
MPType x_pos2_mp = static_cast<MPType>(x[pos2]);
MPType y_pos0_mp = static_cast<MPType>(y[pos0]);
MPType y_pos1_mp = static_cast<MPType>(y[pos1]);
MPType y_pos2_mp = static_cast<MPType>(y[pos2]);
MPType out_pos0_mp = static_cast<MPType>(out[pos0]);
MPType out_pos1_mp = static_cast<MPType>(out[pos1]);
MPType out_pos2_mp = static_cast<MPType>(out[pos2]);
out_dx[pos0] =
static_cast<T>(out_pos2_mp * y_pos1_mp - out_pos1_mp * y_pos2_mp);
out_dy[pos0] =
static_cast<T>(out_pos1_mp * x_pos2_mp - out_pos2_mp * x_pos1_mp);
out_dx[pos1] =
static_cast<T>(out_pos0_mp * y_pos2_mp - out_pos2_mp * y_pos0_mp);
out_dy[pos1] =
static_cast<T>(out_pos2_mp * x_pos0_mp - out_pos0_mp * x_pos2_mp);
out_dx[pos2] =
static_cast<T>(out_pos1_mp * y_pos0_mp - out_pos0_mp * y_pos1_mp);
out_dy[pos2] =
static_cast<T>(out_pos0_mp * x_pos1_mp - out_pos1_mp * x_pos0_mp);
}
}
......@@ -172,6 +190,7 @@ PD_REGISTER_KERNEL(cross_grad,
GPU,
ALL_LAYOUT,
phi::CrossGradKernel,
phi::dtype::float16,
float,
double,
int,
......
......@@ -16,6 +16,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/index_calculator.h"
......@@ -36,9 +37,18 @@ __global__ void Cross(const T* x,
auto pos1 = offset + 1 * stride;
auto pos2 = offset + 2 * stride;
out[pos0] = x[pos1] * y[pos2] - x[pos2] * y[pos1];
out[pos1] = x[pos2] * y[pos0] - x[pos0] * y[pos2];
out[pos2] = x[pos0] * y[pos1] - x[pos1] * y[pos0];
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType x_pos0_mp = static_cast<MPType>(x[pos0]);
MPType x_pos1_mp = static_cast<MPType>(x[pos1]);
MPType x_pos2_mp = static_cast<MPType>(x[pos2]);
MPType y_pos0_mp = static_cast<MPType>(y[pos0]);
MPType y_pos1_mp = static_cast<MPType>(y[pos1]);
MPType y_pos2_mp = static_cast<MPType>(y[pos2]);
out[pos0] = static_cast<T>(x_pos1_mp * y_pos2_mp - x_pos2_mp * y_pos1_mp);
out[pos1] = static_cast<T>(x_pos2_mp * y_pos0_mp - x_pos0_mp * y_pos2_mp);
out[pos2] = static_cast<T>(x_pos0_mp * y_pos1_mp - x_pos1_mp * y_pos0_mp);
}
}
......@@ -153,5 +163,12 @@ void CrossKernel(const Context& dev_ctx,
}
} // namespace phi
PD_REGISTER_KERNEL(
cross, GPU, ALL_LAYOUT, phi::CrossKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(cross,
GPU,
ALL_LAYOUT,
phi::CrossKernel,
phi::dtype::float16,
float,
double,
int,
int64_t) {}
......@@ -65,6 +65,18 @@ class TestCrossOpCase1(TestCrossOp):
self.outputs = {'Out': np.array(z_list).reshape(self.shape)}
class TestCrossFP16Op(TestCrossOp):
def initTestCase(self):
self.shape = (2048, 3)
self.dtype = np.float16
def init_output(self):
z_list = []
for i in range(2048):
z_list.append(np.cross(self.inputs['X'][i], self.inputs['Y'][i]))
self.outputs = {'Out': np.array(z_list).reshape(self.shape)}
class TestCrossAPI(unittest.TestCase):
def input_data(self):
self.data_x = np.array(
......
......@@ -1334,8 +1334,8 @@ def cross(x, y, axis=9, name=None):
If `axis` is not given, it defaults to the first axis found with the length 3.
Args:
x (Tensor): The first input tensor.
y (Tensor): The second input tensor.
x (Tensor): The first input tensor, the data type is float16, float32, float64, int32, int64.
y (Tensor): The second input tensor, the data type is float16, float32, float64, int32, int64.
axis (int, optional): The axis along which to compute the cross product. It defaults to be 9 which indicates using the first axis found with the length 3.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
......@@ -1368,6 +1368,18 @@ def cross(x, y, axis=9, name=None):
axis = K_DEFAULT_DIM if axis is None else axis
return _C_ops.cross(x, y, axis)
else:
check_variable_and_dtype(
x,
'x',
['float16', 'float32', 'float64', "int32", "int64"],
'cross',
)
check_variable_and_dtype(
y,
'y',
['float16', 'float32', 'float64', "int32", "int64"],
'cross',
)
helper = LayerHelper("cross", **locals())
out = helper.create_variable_for_type_inference(x.dtype)
attrs = dict()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册