diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
index f42f316aae98038179282cbf9ade4844f3065868..8a1f31a5cfe37be9a3c9c53d5bd1d839d5b0d04a 100644
--- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
@@ -18,6 +18,7 @@
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/broadcast_tensors_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
 #include "paddle/phi/kernels/funcs/common_shape.h"
@@ -35,16 +36,18 @@ __global__ void LerpGradKernelImpl(const T* weight,
                                    const int out_size,
                                    const int x_size,
                                    const int y_size) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   CUDA_KERNEL_LOOP_TYPE(idx, out_size, int64_t) {
-    T temp_dx = weight[idx] * dout[idx];
+    MPType temp_dx =
+        static_cast<MPType>(weight[idx]) * static_cast<MPType>(dout[idx]);
     if (dx) {
       if (idx < x_size) {
-        dx[idx] = dout[idx] - temp_dx;
+        dx[idx] = static_cast<T>(static_cast<MPType>(dout[idx]) - temp_dx);
       }
     }
     if (dy) {
       if (idx < y_size) {
-        dy[idx] = temp_dx;
+        dy[idx] = static_cast<T>(temp_dx);
       }
     }
   }
@@ -58,17 +61,18 @@ __global__ void LerpGradScalarKernelImpl(const T* weight,
                                          const int out_size,
                                          const int x_size,
                                          const int y_size) {
-  T weight_scalar = weight[0];
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+  MPType weight_scalar = static_cast<MPType>(weight[0]);
   CUDA_KERNEL_LOOP_TYPE(idx, out_size, int64_t) {
-    T temp_dx = weight_scalar * dout[idx];
+    MPType temp_dx = weight_scalar * static_cast<MPType>(dout[idx]);
     if (dx) {
       if (idx < x_size) {
-        dx[idx] = dout[idx] - temp_dx;
+        dx[idx] = static_cast<T>(static_cast<MPType>(dout[idx]) - temp_dx);
       }
     }
     if (dy) {
       if (idx < y_size) {
-        dy[idx] = temp_dx;
+        dy[idx] = static_cast<T>(temp_dx);
       }
     }
   }
@@ -270,5 +274,10 @@ void LerpGradKernel(const Context& ctx,

 }  // namespace phi

-PD_REGISTER_KERNEL(
-    lerp_grad, GPU, ALL_LAYOUT, phi::LerpGradKernel, float, double) {}
+PD_REGISTER_KERNEL(lerp_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LerpGradKernel,
+                   phi::dtype::float16,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/gpu/lerp_kernel.cu b/paddle/phi/kernels/gpu/lerp_kernel.cu
index 3f6862ff9795e2ae3bd78fdf0081edb88022b2e8..25f37bb170476ba9acccc57270e20bd223e9f891 100644
--- a/paddle/phi/kernels/gpu/lerp_kernel.cu
+++ b/paddle/phi/kernels/gpu/lerp_kernel.cu
@@ -18,4 +18,10 @@
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/lerp_kernel_impl.h"

-PD_REGISTER_KERNEL(lerp, GPU, ALL_LAYOUT, phi::LerpKernel, float, double) {}
+PD_REGISTER_KERNEL(lerp,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LerpKernel,
+                   phi::dtype::float16,
+                   float,
+                   double) {}
diff --git a/paddle/phi/kernels/impl/lerp_kernel_impl.h b/paddle/phi/kernels/impl/lerp_kernel_impl.h
index 668349e09b951e46b83ed9dd12343a537878a22a..ad41b4e26367ad2216e35c1ab8290696b09e039f 100644
--- a/paddle/phi/kernels/impl/lerp_kernel_impl.h
+++ b/paddle/phi/kernels/impl/lerp_kernel_impl.h
@@ -15,6 +15,7 @@
 #pragma once

 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/funcs/common_shape.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"

@@ -43,11 +44,14 @@ static void LerpFunction(const Context& ctx,
   auto eigen_w = phi::EigenTensor<T, D>::From(weight, w_dims);
   auto eigen_out = phi::EigenTensor<T, D>::From(*out);

+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   auto& place = *ctx.eigen_device();
   eigen_out.device(place) =
-      eigen_x.broadcast(x_bcast_dims) +
-      eigen_w.broadcast(w_bcast_dims) *
-          (eigen_y.broadcast(y_bcast_dims) - eigen_x.broadcast(x_bcast_dims));
+      (eigen_x.broadcast(x_bcast_dims).template cast<MPType>() +
+       eigen_w.broadcast(w_bcast_dims).template cast<MPType>() *
+           (eigen_y.broadcast(y_bcast_dims).template cast<MPType>() -
+            eigen_x.broadcast(x_bcast_dims).template cast<MPType>()))
+          .template cast<T>();
 }

 template <typename Context, typename T>
@@ -64,8 +68,13 @@ static void LerpFunctionZero(const Context& ctx,
   auto eigen_w = phi::EigenTensor<T, 1>::From(weight, dim);
   auto eigen_out = phi::EigenTensor<T, 1>::From(*out, dim);

+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   auto& place = *ctx.eigen_device();
-  eigen_out.device(place) = eigen_x + eigen_w * (eigen_y - eigen_x);
+  eigen_out.device(place) =
+      (eigen_x.template cast<MPType>() +
+       eigen_w.template cast<MPType>() *
+           (eigen_y.template cast<MPType>() - eigen_x.template cast<MPType>()))
+          .template cast<T>();
 }

 template <typename Context, typename T>
diff --git a/python/paddle/fluid/tests/unittests/test_lerp_op.py b/python/paddle/fluid/tests/unittests/test_lerp_op.py
index d2288d69d999ec0845d705bee310898627bda8f2..7bf2cb73380407ac96a2ab35d1c648cdb08181bd 100644
--- a/python/paddle/fluid/tests/unittests/test_lerp_op.py
+++ b/python/paddle/fluid/tests/unittests/test_lerp_op.py
@@ -30,9 +30,11 @@ class TestLerp(OpTest):
         self.python_api = paddle.lerp
         self.init_dtype()
         self.init_shape()
-        x = np.arange(1.0, 101.0).astype(self.dtype).reshape(self.shape)
-        y = np.full(100, 10.0).astype(self.dtype).reshape(self.shape)
-        w = np.asarray([0.5]).astype(self.dtype)
+        self.init_xyshape()
+        self.init_wshape()
+        x = np.arange(1.0, 101.0).astype(self.dtype).reshape(self.xshape)
+        y = np.full(100, 10.0).astype(self.dtype).reshape(self.yshape)
+        w = np.random.random(self.wshape).astype(self.dtype)
         self.inputs = {'X': x, 'Y': y, 'Weight': w}
         self.outputs = {'Out': x + w * (y - x)}

@@ -42,6 +44,13 @@ class TestLerp(OpTest):
     def init_shape(self):
         self.shape = [100]

+    def init_xyshape(self):
+        self.xshape = self.shape
+        self.yshape = self.shape
+
+    def init_wshape(self):
+        self.wshape = [1]
+
     def test_check_output(self):
         self.check_output()

@@ -74,30 +83,46 @@ class TestLerpWithDim6(TestLerp):
         self.shape = [2, 1, 2, 5, 1, 5]


+class TestLerpWithDim6Fp16(TestLerp):
+    def init_shape(self):
+        self.shape = [2, 1, 2, 5, 1, 5]
+
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+class TestLerpWithFp16BroadXY(TestLerp):
+    def init_xyshape(self):
+        self.xshape = [2, 1, 2, 5, 5]
+        self.yshape = [2, 2, 1, 5, 5]
+
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+class TestLerpWithFp16BroadWToXY(TestLerp):
+    def init_shape(self):
+        self.shape = [2, 2, 5, 5]
+
+    def init_wshape(self):
+        self.wshape = [5]
+
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
 class TestLerpBroadXY(TestLerp):
-    def setUp(self):
-        self.op_type = "lerp"
-        self.python_api = paddle.lerp
-        self.init_dtype()
-        self.init_shape()
-        x = np.arange(1.0, 201.0).astype(self.dtype).reshape([2, 1, 2, 50])
-        y = np.full(200, 10.0).astype(self.dtype).reshape([2, 2, 1, 50])
-        w = np.asarray([0.5]).astype(self.dtype)
-        self.inputs = {'X': x, 'Y': y, 'Weight': w}
-        self.outputs = {'Out': x + w * (y - x)}
+    def init_xyshape(self):
+        self.xshape = [2, 1, 2, 5, 5]
+        self.yshape = [2, 2, 1, 5, 5]


 class TestLerpBroadWToXY(TestLerp):
-    def setUp(self):
-        self.op_type = "lerp"
-        self.python_api = paddle.lerp
-        self.init_dtype()
-        self.init_shape()
-        x = np.full(600, 2.5).astype(self.dtype).reshape([50, 2, 2, 3])
-        y = np.full(600, 1.0).astype(self.dtype).reshape([50, 2, 2, 3])
-        w = np.random.random([3]).astype(self.dtype)
-        self.inputs = {'X': x, 'Y': y, 'Weight': w}
-        self.outputs = {'Out': x + w * (y - x)}
+    def init_shape(self):
+        self.shape = [2, 2, 5, 5]
+
+    def init_wshape(self):
+        self.wshape = [5]


 class TestLerpAPI(unittest.TestCase):
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 6676d4dc604c2d8236f710ece36f3424b5c0bafa..cc662d83457fa268e12dab7c06ce0d4513bc1499 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -4215,9 +4215,9 @@ def lerp(x, y, weight, name=None):
         lerp(x, y, weight) = x + weight * (y - x).

     Args:
-        x (Tensor): An N-D Tensor with starting points, the data type is float32, float64.
-        y (Tensor): An N-D Tensor with ending points, the data type is float32, float64.
-        weight (float|Tensor): The weight for the interpolation formula. When weight is Tensor, the data type is float32, float64.
+        x (Tensor): An N-D Tensor with starting points, the data type is float16, float32, float64.
+        y (Tensor): An N-D Tensor with ending points, the data type is float16, float32, float64.
+        weight (float|Tensor): The weight for the interpolation formula. When weight is Tensor, the data type is float16, float32, float64.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

     Returns:
@@ -4241,10 +4241,14 @@ def lerp(x, y, weight, name=None):
     if in_dygraph_mode():
         return _C_ops.lerp(x, y, weight)
     else:
-        check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'lerp')
-        check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'lerp')
         check_variable_and_dtype(
-            weight, 'weight', ['float32', 'float64'], 'lerp'
+            x, 'x', ['float16', 'float32', 'float64'], 'lerp'
+        )
+        check_variable_and_dtype(
+            y, 'y', ['float16', 'float32', 'float64'], 'lerp'
+        )
+        check_variable_and_dtype(
+            weight, 'weight', ['float16', 'float32', 'float64'], 'lerp'
         )

         helper = LayerHelper('lerp', **locals())
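
Usage sketch (reviewer note, not part of the diff): a minimal end-to-end check that the new float16 path is exercised. This assumes a CUDA build of Paddle, since this change registers the fp16 lerp/lerp_grad kernels for GPU only; the input values and tolerance below are illustrative, not taken from the PR.

    import numpy as np
    import paddle

    # fp16 lerp is only registered for the GPU kernels in this diff.
    paddle.set_device('gpu')

    # float16 inputs dispatch to the newly registered fp16 kernels; internally
    # the math runs in float32 (MPTypeTrait<float16>::Type) and the result is
    # cast back to float16 on store.
    x = paddle.arange(1.0, 101.0).astype('float16')
    y = paddle.full([100], 10.0, dtype='float16')
    w = paddle.to_tensor([0.5], dtype='float16')

    out = paddle.lerp(x, y, w)

    # Compare against a float32 reference at an fp16-appropriate tolerance.
    ref = x.astype('float32') + w.astype('float32') * (
        y.astype('float32') - x.astype('float32')
    )
    np.testing.assert_allclose(
        out.astype('float32').numpy(), ref.numpy(), rtol=1e-3
    )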
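
Why the kernels upcast to MPType instead of computing directly in fp16: each fp16 intermediate rounds to a 10-bit mantissa, so x + w * (y - x) evaluated purely in fp16 rounds three times, while the MPType version rounds once on the final store. A rough numpy illustration of the effect follows (not from the PR; exact numbers vary with the seed, and the GPU kernel may additionally fuse operations):

    import numpy as np

    rng = np.random.default_rng(0)
    # Start from fp16 inputs so only arithmetic rounding is compared.
    x = rng.random(10000, dtype=np.float32).astype(np.float16)
    y = rng.random(10000, dtype=np.float32).astype(np.float16)
    w = np.float16(0.3)  # not exactly representable, so the multiply rounds

    # High-precision reference on the same fp16 inputs.
    ref = x.astype(np.float64) + np.float64(w) * (
        y.astype(np.float64) - x.astype(np.float64)
    )

    # Pure fp16 arithmetic: numpy rounds each float16 operation result back to
    # float16, modeling three roundings (subtract, multiply, add).
    pure16 = (x + w * (y - x)).astype(np.float64)

    # MPType-style: fp32 arithmetic, one rounding on the final fp16 store.
    mp = (
        x.astype(np.float32)
        + np.float32(w) * (y.astype(np.float32) - x.astype(np.float32))
    ).astype(np.float16).astype(np.float64)

    print('pure fp16 max error:', np.abs(pure16 - ref).max())
    print('mptype   max error:', np.abs(mp - ref).max())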