diff --git a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
index 8a1f31a5cfe37be9a3c9c53d5bd1d839d5b0d04a..43cf0deab6dd9dc445e8326e83627ea3b98fb12e 100644
--- a/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/lerp_grad_kernel.cu
@@ -279,5 +279,6 @@ PD_REGISTER_KERNEL(lerp_grad,
                    ALL_LAYOUT,
                    phi::LerpGradKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double) {}
diff --git a/paddle/phi/kernels/gpu/lerp_kernel.cu b/paddle/phi/kernels/gpu/lerp_kernel.cu
index e93e2bb21924c020442e8f2add8e67acec1f5e2b..17964760990cc34cc99d214aaac34277aebd24ba 100644
--- a/paddle/phi/kernels/gpu/lerp_kernel.cu
+++ b/paddle/phi/kernels/gpu/lerp_kernel.cu
@@ -127,5 +127,6 @@ PD_REGISTER_KERNEL(lerp,
                    ALL_LAYOUT,
                    phi::LerpKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double) {}
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 8b5af17b86f23935f7a38b84e6836e192d25e313..9998ec79ef72598c9291649e6933e5863ed016b5 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -4512,9 +4512,9 @@ def lerp(x, y, weight, name=None):
             lerp(x, y, weight) = x + weight * (y - x).
 
     Args:
-        x (Tensor): An N-D Tensor with starting points, the data type is float16, float32, float64.
-        y (Tensor): An N-D Tensor with ending points, the data type is float16, float32, float64.
-        weight (float|Tensor): The weight for the interpolation formula. When weight is Tensor, the data type is float16, float32, float64.
+        x (Tensor): An N-D Tensor with starting points, the data type is bfloat16, float16, float32, float64.
+        y (Tensor): An N-D Tensor with ending points, the data type is bfloat16, float16, float32, float64.
+        weight (float|Tensor): The weight for the interpolation formula. When weight is Tensor, the data type is bfloat16, float16, float32, float64.
         name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
@@ -4539,13 +4539,16 @@ def lerp(x, y, weight, name=None):
         return _C_ops.lerp(x, y, weight)
     else:
         check_variable_and_dtype(
-            x, 'x', ['float16', 'float32', 'float64'], 'lerp'
+            x, 'x', ['uint16', 'float16', 'float32', 'float64'], 'lerp'
         )
         check_variable_and_dtype(
-            y, 'y', ['float16', 'float32', 'float64'], 'lerp'
+            y, 'y', ['uint16', 'float16', 'float32', 'float64'], 'lerp'
         )
         check_variable_and_dtype(
-            weight, 'weight', ['float16', 'float32', 'float64'], 'lerp'
+            weight,
+            'weight',
+            ['uint16', 'float16', 'float32', 'float64'],
+            'lerp',
         )
 
         helper = LayerHelper('lerp', **locals())
diff --git a/test/legacy_test/test_lerp_op.py b/test/legacy_test/test_lerp_op.py
index 7bf2cb73380407ac96a2ab35d1c648cdb08181bd..7966e9a4b98f5c313fe91039836d3d512aa5b12f 100644
--- a/test/legacy_test/test_lerp_op.py
+++ b/test/legacy_test/test_lerp_op.py
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle.fluid import core
@@ -220,5 +220,65 @@ class TestLerpAPI(unittest.TestCase):
         paddle.enable_static()
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestLerpBF16(TestLerp):
+    def setUp(self):
+        self.op_type = "lerp"
+        self.python_api = paddle.lerp
+        self.dtype = np.uint16
+        self.init_shape()
+        self.init_xyshape()
+        self.init_wshape()
+        x = np.arange(1.0, 101.0).astype("float32").reshape(self.xshape)
+        y = np.full(100, 10.0).astype("float32").reshape(self.yshape)
+        w = np.random.random(self.wshape).astype("float32")
+        self.init_grad(w)
+        self.inputs = {
+            'X': convert_float_to_uint16(x),
+            'Y': convert_float_to_uint16(y),
+            'Weight': convert_float_to_uint16(w),
+        }
+        self.outputs = {'Out': convert_float_to_uint16(x + w * (y - x))}
+
+    def init_shape(self):
+        self.shape = [100]
+
+    def init_xyshape(self):
+        self.xshape = self.shape
+        self.yshape = self.shape
+
+    def init_wshape(self):
+        self.wshape = [1]
+
+    def init_grad(self, w):
+        self.x_grad = (
+            np.ones(self.xshape)
+            * (1 - w)
+            / (np.prod(self.xshape) / np.prod(self.wshape))
+        )
+        self.y_grad = (
+            np.ones(self.yshape)
+            * w
+            / (np.prod(self.yshape) / np.prod(self.wshape))
+        )
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place,
+            ['X', 'Y'],
+            'Out',
+            user_defined_grads=[self.x_grad, self.y_grad],
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
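
Note for reviewers (not part of the diff above): the snippet below is a minimal sketch of how the new dtype path could be exercised by hand, assuming a CUDA build of Paddle on a GPU that supports bfloat16. It simply runs paddle.lerp on bfloat16 inputs and checks the result against lerp(x, y, weight) = x + weight * (y - x).

import paddle

# Assumes a CUDA build of Paddle and a GPU with bfloat16 support.
x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0]).astype('bfloat16')
y = paddle.full([4], 10.0).astype('bfloat16')

# Scalar weight; a bfloat16 tensor weight should also be accepted
# now that the kernel is registered for that dtype.
out = paddle.lerp(x, y, 0.5)

# bfloat16 has no direct numpy counterpart, so cast up before inspecting.
print(out.astype('float32').numpy())  # expected: [5.5, 6.0, 6.5, 7.0]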