Unverified commit a2cbc81a authored by denglianbin, committed by GitHub

【Hackathon No.50】Add float16 data type support for the Paddle lerp operator (#50925)

* finish task

* fix error

* pre-commit fix code style

* add unittest.

* change unittest.

* delete unittest case.
Parent f547ee92
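The kernels changed below follow Paddle's usual AMP pattern: float16 operands are promoted to `MPTypeTrait<T>::Type` (float for float16) before the arithmetic and cast back to `T` only when the result is stored. A minimal numpy sketch of that pattern, for illustration only (not Paddle code):

```python
import numpy as np

def lerp_fp16(x, y, w):
    """lerp with float32 intermediates and float16 storage.

    Mirrors the MPType pattern in the kernels below: cast fp16 inputs up,
    do the arithmetic in fp32, cast the result back down to fp16.
    """
    x32, y32, w32 = (t.astype(np.float32) for t in (x, y, w))
    return (x32 + w32 * (y32 - x32)).astype(np.float16)

x = np.arange(1.0, 101.0, dtype=np.float16)
y = np.full(100, 10.0, dtype=np.float16)
w = np.asarray([0.5], dtype=np.float16)

out = lerp_fp16(x, y, w)
ref = x.astype(np.float64) + 0.5 * (y.astype(np.float64) - x.astype(np.float64))
print(np.max(np.abs(out.astype(np.float64) - ref)))  # only fp16 rounding error remains
```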
......@@ -18,6 +18,7 @@
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/broadcast_tensors_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
......@@ -35,16 +36,18 @@ __global__ void LerpGradKernelImpl(const T* weight,
const int out_size,
const int x_size,
const int y_size) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
CUDA_KERNEL_LOOP_TYPE(idx, out_size, int64_t) {
T temp_dx = weight[idx] * dout[idx];
MPType temp_dx =
static_cast<MPType>(weight[idx]) * static_cast<MPType>(dout[idx]);
if (dx) {
if (idx < x_size) {
dx[idx] = dout[idx] - temp_dx;
dx[idx] = static_cast<T>(static_cast<MPType>(dout[idx]) - temp_dx);
}
}
if (dy) {
if (idx < y_size) {
dy[idx] = temp_dx;
dy[idx] = static_cast<T>(temp_dx);
}
}
}
......@@ -58,17 +61,18 @@ __global__ void LerpGradScalarKernelImpl(const T* weight,
const int out_size,
const int x_size,
const int y_size) {
T weight_scalar = weight[0];
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType weight_scalar = static_cast<MPType>(weight[0]);
CUDA_KERNEL_LOOP_TYPE(idx, out_size, int64_t) {
T temp_dx = weight_scalar * dout[idx];
MPType temp_dx = weight_scalar * static_cast<MPType>(dout[idx]);
if (dx) {
if (idx < x_size) {
dx[idx] = dout[idx] - temp_dx;
dx[idx] = static_cast<T>(static_cast<MPType>(dout[idx]) - temp_dx);
}
}
if (dy) {
if (idx < y_size) {
dy[idx] = temp_dx;
dy[idx] = static_cast<T>(temp_dx);
}
}
}
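For reference, the backward formulas both kernels implement follow from out = x + weight * (y - x): d(out)/dx = 1 - weight and d(out)/dy = weight, so dx = dout - weight * dout and dy = weight * dout. A small numpy check of those formulas against finite differences (an illustrative sketch, not part of the patch):

```python
import numpy as np

rng = np.random.default_rng(0)
x, y = rng.standard_normal(8), rng.standard_normal(8)
w = rng.random(8)
dout = rng.standard_normal(8)

# Analytic gradients, as computed in LerpGradKernelImpl / LerpGradScalarKernelImpl.
temp_dx = w * dout      # weight * dout
dx = dout - temp_dx     # dout * (1 - weight)
dy = temp_dx            # dout * weight

# Elementwise finite-difference check of d(lerp)/dx and d(lerp)/dy.
eps = 1e-6
lerp = lambda a, b: a + w * (b - a)
fd_dx = (lerp(x + eps, y) - lerp(x - eps, y)) / (2 * eps) * dout
fd_dy = (lerp(x, y + eps) - lerp(x, y - eps)) / (2 * eps) * dout
assert np.allclose(dx, fd_dx) and np.allclose(dy, fd_dy)
```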
......@@ -270,5 +274,10 @@ void LerpGradKernel(const Context& ctx,
} // namespace phi
PD_REGISTER_KERNEL(
lerp_grad, GPU, ALL_LAYOUT, phi::LerpGradKernel, float, double) {}
PD_REGISTER_KERNEL(lerp_grad,
GPU,
ALL_LAYOUT,
phi::LerpGradKernel,
phi::dtype::float16,
float,
double) {}
......@@ -18,4 +18,10 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/lerp_kernel_impl.h"
PD_REGISTER_KERNEL(lerp, GPU, ALL_LAYOUT, phi::LerpKernel, float, double) {}
PD_REGISTER_KERNEL(lerp,
GPU,
ALL_LAYOUT,
phi::LerpKernel,
phi::dtype::float16,
float,
double) {}
......@@ -15,6 +15,7 @@
#pragma once
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
......@@ -43,11 +44,14 @@ static void LerpFunction(const Context& ctx,
auto eigen_w = phi::EigenTensor<T, D>::From(weight, w_dims);
auto eigen_out = phi::EigenTensor<T, D>::From(*out);
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
auto& place = *ctx.eigen_device();
eigen_out.device(place) =
eigen_x.broadcast(x_bcast_dims) +
eigen_w.broadcast(w_bcast_dims) *
(eigen_y.broadcast(y_bcast_dims) - eigen_x.broadcast(x_bcast_dims));
(eigen_x.broadcast(x_bcast_dims).template cast<MPType>() +
eigen_w.broadcast(w_bcast_dims).template cast<MPType>() *
(eigen_y.broadcast(y_bcast_dims).template cast<MPType>() -
eigen_x.broadcast(x_bcast_dims).template cast<MPType>()))
.template cast<T>();
}
template <typename Context, typename T>
......@@ -64,8 +68,13 @@ static void LerpFunctionZero(const Context& ctx,
auto eigen_w = phi::EigenTensor<T, 1>::From(weight, dim);
auto eigen_out = phi::EigenTensor<T, 1>::From(*out, dim);
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
auto& place = *ctx.eigen_device();
eigen_out.device(place) = eigen_x + eigen_w * (eigen_y - eigen_x);
eigen_out.device(place) =
(eigen_x.template cast<MPType>() +
eigen_w.template cast<MPType>() *
(eigen_y.template cast<MPType>() - eigen_x.template cast<MPType>()))
.template cast<T>();
}
template <typename T, typename Context>
......
......@@ -30,9 +30,11 @@ class TestLerp(OpTest):
self.python_api = paddle.lerp
self.init_dtype()
self.init_shape()
x = np.arange(1.0, 101.0).astype(self.dtype).reshape(self.shape)
y = np.full(100, 10.0).astype(self.dtype).reshape(self.shape)
w = np.asarray([0.5]).astype(self.dtype)
self.init_xyshape()
self.init_wshape()
x = np.arange(1.0, 101.0).astype(self.dtype).reshape(self.xshape)
y = np.full(100, 10.0).astype(self.dtype).reshape(self.yshape)
w = np.random.random(self.wshape).astype(self.dtype)
self.inputs = {'X': x, 'Y': y, 'Weight': w}
self.outputs = {'Out': x + w * (y - x)}
......@@ -42,6 +44,13 @@ class TestLerp(OpTest):
def init_shape(self):
self.shape = [100]
def init_xyshape(self):
self.xshape = self.shape
self.yshape = self.shape
def init_wshape(self):
self.wshape = [1]
def test_check_output(self):
self.check_output()
......@@ -74,30 +83,46 @@ class TestLerpWithDim6(TestLerp):
self.shape = [2, 1, 2, 5, 1, 5]
class TestLerpWithDim6Fp16(TestLerp):
def init_shape(self):
self.shape = [2, 1, 2, 5, 1, 5]
def init_dtype(self):
self.dtype = np.float16
class TestLerpWihFp16BroadXY(TestLerp):
def init_xyshape(self):
self.xshape = [2, 1, 2, 5, 5]
self.yshape = [2, 2, 1, 5, 5]
def init_dtype(self):
self.dtype = np.float16
class TestLerpWithFp16BroadWToXY(TestLerp):
def init_shape(self):
self.shape = [2, 2, 5, 5]
def init_wshape(self):
self.wshape = [5]
def init_dtype(self):
self.dtype = np.float16
class TestLerpBroadXY(TestLerp):
def setUp(self):
self.op_type = "lerp"
self.python_api = paddle.lerp
self.init_dtype()
self.init_shape()
x = np.arange(1.0, 201.0).astype(self.dtype).reshape([2, 1, 2, 50])
y = np.full(200, 10.0).astype(self.dtype).reshape([2, 2, 1, 50])
w = np.asarray([0.5]).astype(self.dtype)
self.inputs = {'X': x, 'Y': y, 'Weight': w}
self.outputs = {'Out': x + w * (y - x)}
def init_xyshape(self):
self.xshape = [2, 1, 2, 5, 5]
self.yshape = [2, 2, 1, 5, 5]
class TestLerpBroadWToXY(TestLerp):
def setUp(self):
self.op_type = "lerp"
self.python_api = paddle.lerp
self.init_dtype()
self.init_shape()
x = np.full(600, 2.5).astype(self.dtype).reshape([50, 2, 2, 3])
y = np.full(600, 1.0).astype(self.dtype).reshape([50, 2, 2, 3])
w = np.random.random([3]).astype(self.dtype)
self.inputs = {'X': x, 'Y': y, 'Weight': w}
self.outputs = {'Out': x + w * (y - x)}
def init_shape(self):
self.shape = [2, 2, 5, 5]
def init_wshape(self):
self.wshape = [5]
class TestLerpAPI(unittest.TestCase):
......
......@@ -4215,9 +4215,9 @@ def lerp(x, y, weight, name=None):
lerp(x, y, weight) = x + weight * (y - x).
Args:
x (Tensor): An N-D Tensor with starting points, the data type is float32, float64.
y (Tensor): An N-D Tensor with ending points, the data type is float32, float64.
weight (float|Tensor): The weight for the interpolation formula. When weight is Tensor, the data type is float32, float64.
x (Tensor): An N-D Tensor with starting points, the data type is float16, float32, float64.
y (Tensor): An N-D Tensor with ending points, the data type is float16, float32, float64.
weight (float|Tensor): The weight for the interpolation formula. When weight is Tensor, the data type is float16, float32, float64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
......@@ -4241,10 +4241,14 @@ def lerp(x, y, weight, name=None):
if in_dygraph_mode():
return _C_ops.lerp(x, y, weight)
else:
check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'lerp')
check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'lerp')
check_variable_and_dtype(
weight, 'weight', ['float32', 'float64'], 'lerp'
x, 'x', ['float16', 'float32', 'float64'], 'lerp'
)
check_variable_and_dtype(
y, 'y', ['float16', 'float32', 'float64'], 'lerp'
)
check_variable_and_dtype(
weight, 'weight', ['float16', 'float32', 'float64'], 'lerp'
)
helper = LayerHelper('lerp', **locals())
......
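With this change, paddle.lerp also accepts float16 inputs; the float16 kernels are registered for GPU only in this patch. A minimal usage sketch, assuming a CUDA build of Paddle (values are illustrative):

```python
import paddle

# float16 lerp kernels are only registered for GPU in this commit.
paddle.set_device('gpu')

x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0], dtype='float16')
y = paddle.to_tensor([10.0, 10.0, 10.0, 10.0], dtype='float16')
w = paddle.to_tensor([0.5], dtype='float16')

out = paddle.lerp(x, y, w)  # x + w * (y - x)
print(out)                  # [5.5, 6.0, 6.5, 7.0]
```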