diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h
index 0b9b852a7585d02308179c1d00b889ac508969ee..055991af1dd9f67ec5d2e2f1824dba59c9c5c35f 100644
--- a/paddle/phi/kernels/funcs/reduce_function.h
+++ b/paddle/phi/kernels/funcs/reduce_function.h
@@ -905,14 +905,16 @@ template <typename Tx,
           typename Ty,
           template <typename> class ReduceOp,
           typename TransformOp>
-static typename std::enable_if<!std::is_same<Tx, phi::dtype::float16>::value,
-                               void>::type
-CubTensorReduceImpl(const Tx* x_data,
-                    Ty* y_data,
-                    const TransformOp& transform,
-                    int reduce_num,
-                    const KPDevice& dev_ctx,
-                    KPStream stream) {
+static
+    typename std::enable_if<!std::is_same<Tx, phi::dtype::float16>::value &&
+                                !std::is_same<Tx, phi::dtype::bfloat16>::value,
+                            void>::type
+    CubTensorReduceImpl(const Tx* x_data,
+                        Ty* y_data,
+                        const TransformOp& transform,
+                        int reduce_num,
+                        const KPDevice& dev_ctx,
+                        KPStream stream) {
   auto reducer = ReduceOp<Ty>();
   cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
                                                                   transform);
@@ -956,6 +958,23 @@ CubTensorReduceImpl(const Tx* x_data,
   PADDLE_THROW(phi::errors::InvalidArgument(
       "Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
 }
+
+template <typename Tx,
+          typename Ty,
+          template <typename> class ReduceOp,
+          typename TransformOp>
+static typename std::enable_if<std::is_same<Tx, phi::dtype::bfloat16>::value,
+                               void>::type
+CubTensorReduceImpl(const Tx* x_data,
+                    Ty* y_data,
+                    const TransformOp& transform,
+                    int reduce_num,
+                    const KPDevice& dev_ctx,
+                    KPStream stream) {
+  PADDLE_THROW(phi::errors::InvalidArgument(
+      "Tx should not be bfloat16 when using cub::DeviceReduce::Reduce()."));
+}
+
 #endif  // PADDLE_WITH_XPU_KP
 
 template <typename Tx,
@@ ... @@
   constexpr bool kIsTxFP16 = std::is_same<Tx, phi::dtype::float16>::value;
-  bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16;
+  constexpr bool kIsTxBF16 = std::is_same<Tx, phi::dtype::bfloat16>::value;
+  bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16 && !kIsTxBF16;
 #ifndef PADDLE_WITH_XPU_KP
   if (use_cub_reduce) {
     if (is_mean) {
diff --git a/paddle/phi/kernels/funcs/uniform_real_distribution.h b/paddle/phi/kernels/funcs/uniform_real_distribution.h
index 07318d4b6df15a1641c0eefc23c0795959adcea0..e24ebbd230ebd8503009a789b56799ec8877b2c3 100644
--- a/paddle/phi/kernels/funcs/uniform_real_distribution.h
+++ b/paddle/phi/kernels/funcs/uniform_real_distribution.h
@@ -46,4 +46,16 @@ inline void UniformRealDistribution(phi::dtype::bfloat16 *data,
   }
 }
 
+template <>
+inline void UniformRealDistribution(phi::dtype::float16 *data,
+                                    const int64_t &size,
+                                    const float &min,
+                                    const float &max,
+                                    std::shared_ptr<std::mt19937_64> engine) {
+  std::uniform_real_distribution<float> dist(min, max);
+  for (int64_t i = 0; i < size; ++i) {
+    data[i] = static_cast<phi::dtype::float16>(dist(*engine));
+  }
+}
+
 }  // namespace phi
diff --git a/paddle/phi/kernels/gpu/uniform_kernel.cu b/paddle/phi/kernels/gpu/uniform_kernel.cu
index 277dadabea6d2e50d0e59f7051ce3a380a9e68e1..fe36fe5fc6e1ad70d6e39b02dd431ebdb8e5491e 100644
--- a/paddle/phi/kernels/gpu/uniform_kernel.cu
+++ b/paddle/phi/kernels/gpu/uniform_kernel.cu
@@ -92,4 +92,5 @@ PD_REGISTER_KERNEL(uniform_raw,
                    phi::UniformRawKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/primitive/compute_primitives.h b/paddle/phi/kernels/primitive/compute_primitives.h
index 1dfcde4e5dd0e1673bbdff720bc4ee068ac5abfe..24b961abb9ba47290b01f5766b015e178f7b449e 100644
--- a/paddle/phi/kernels/primitive/compute_primitives.h
+++ b/paddle/phi/kernels/primitive/compute_primitives.h
@@ -52,6 +52,12 @@ class MPTypeTrait<phi::dtype::float16> {
   using Type = float;
 };
 
+template <>
+class MPTypeTrait<phi::dtype::bfloat16> {
+ public:
+  using Type = float;
+};
+
 /**
  * @brief Will be used in BlockYReduce, get the index of reduce_num in shared
  * memory.
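As background for the MPTypeTrait change above: specializing MPTypeTrait<phi::dtype::bfloat16>::Type to float lets kernels built on these primitives accumulate bfloat16 inputs in float32 and narrow only the final result. The standalone numpy sketch below is illustrative only, not Paddle code; bfloat16 is emulated by truncating float32 values to their top 16 bits, and it shows why the wider accumulator matters for a long sum.

```python
import numpy as np


def round_to_bf16(a):
    """Emulate bfloat16 storage: keep only the top 16 bits of each float32."""
    a = np.asarray(a, dtype=np.float32)
    return (a.view(np.uint32) & np.uint32(0xFFFF0000)).view(np.float32)


rng = np.random.default_rng(0)
vals = round_to_bf16(rng.uniform(-5.0, 10.0, size=20000))

# Accumulate in bfloat16: the running sum is re-rounded after every addition
# and stalls once its spacing exceeds the magnitude of the summands.
acc = np.zeros(1, dtype=np.float32)
for v in vals:
    acc = round_to_bf16(acc + v)

# Accumulate in float32 (the compute type the new specialization selects)
# and round once at the end.
total = vals.sum(dtype=np.float32)
fp32_sum = round_to_bf16([total])

print("bf16 accumulator :", float(acc[0]))
print("fp32 accumulator :", float(fp32_sum[0]))
print("float64 reference:", float(vals.sum(dtype=np.float64)))
```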
diff --git a/paddle/phi/kernels/selected_rows/uniform_kernel.cc b/paddle/phi/kernels/selected_rows/uniform_kernel.cc
index 73d00aa9a796e313f75cd803eed69d7f92725d69..89707018002a62a38dc378eacebce0e18abe10ea 100644
--- a/paddle/phi/kernels/selected_rows/uniform_kernel.cc
+++ b/paddle/phi/kernels/selected_rows/uniform_kernel.cc
@@ -78,12 +78,23 @@ PD_REGISTER_KERNEL(uniform_sr,
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 
-PD_REGISTER_KERNEL(
-    uniform_raw_sr, GPU, ALL_LAYOUT, phi::sr::UniformRawKernel, float, double) {
-}
+PD_REGISTER_KERNEL(uniform_raw_sr,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sr::UniformRawKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
-PD_REGISTER_KERNEL(
-    uniform_sr, GPU, ALL_LAYOUT, phi::sr::UniformKernel, float, double) {}
+PD_REGISTER_KERNEL(uniform_sr,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::sr::UniformKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 #endif
 
 #if defined(PADDLE_WITH_XPU)
diff --git a/paddle/phi/kernels/uniform_kernel.cc b/paddle/phi/kernels/uniform_kernel.cc
index 3744fc49d77a01ddc4908da4399d78c14f481f34..7e8138f6e1d63965b80a142c808edabf15ccea4c 100644
--- a/paddle/phi/kernels/uniform_kernel.cc
+++ b/paddle/phi/kernels/uniform_kernel.cc
@@ -56,7 +56,8 @@ PD_REGISTER_KERNEL(uniform,
                    phi::UniformKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 #endif
 
 #ifdef PADDLE_WITH_XPU
diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py
index 407d70b4dadf3aa7effbabb99e769b904db9f957..3965425c7820d695905474a9672b0b7e1cd3f7b8 100644
--- a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py
+++ b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py
@@ -16,18 +16,21 @@ import os
 import unittest
 
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_uint16_to_float
 from test_attribute_var import UnittestBase
 
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddle.fluid import Program, program_guard
+from paddle.fluid.framework import convert_np_dtype_to_dtype_
 from paddle.fluid.op import Operator
 from paddle.tensor import random
 
 
 def output_hist(out):
+    if out.dtype == np.uint16:
+        out = convert_uint16_to_float(out)
     hist, _ = np.histogram(out, range=(-5, 10))
     hist = hist.astype("float32")
     hist /= float(out.size)
@@ -151,15 +154,19 @@ class TestUniformRandomOp(OpTest):
         self.op_type = "uniform_random"
         self.python_api = paddle.uniform
         self.inputs = {}
+        self.init_dtype()
         self.init_attrs()
         self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
 
+    def init_dtype(self):
+        self.dtype = np.float32
+
     def init_attrs(self):
         self.attrs = {
             "shape": [1000, 784],
             "min": -5.0,
             "max": 10.0,
-            "seed": 10,
+            "dtype": convert_np_dtype_to_dtype_(self.dtype),
         }
         self.output_hist = output_hist
 
@@ -176,13 +183,25 @@ class TestUniformRandomOp(OpTest):
             with fluid.dygraph.base.guard(place=place):
                 out = self.python_api(
                     self.attrs['shape'],
-                    'float32',
+                    self.dtype,
                     self.attrs['min'],
                     self.attrs['max'],
-                    self.attrs['seed'],
                 )
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
+)
+class TestUniformRandomFP16Op(TestUniformRandomOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+class TestUniformRandomBF16Op(TestUniformRandomOp):
+    def init_dtype(self):
+        self.dtype = np.uint16
+
+
 class TestUniformRandomOpError(unittest.TestCase):
     def test_errors(self):
         main_prog = Program()
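With the GPU registrations and the new test cases in place, the added dtypes become reachable from the public API. A small usage sketch follows, assuming a CUDA build that contains this change; paddle.uniform samples from [min, max), and the bfloat16 result is cast to float32 here only for easy inspection, mirroring what convert_uint16_to_float does for the uint16-backed test arrays.

```python
import paddle

if paddle.is_compiled_with_cuda():
    paddle.set_device("gpu")

    # float16 and bfloat16 uniform tensors, matching the new kernel registrations.
    x_fp16 = paddle.uniform([1000, 784], dtype="float16", min=-5.0, max=10.0)
    x_bf16 = paddle.uniform([1000, 784], dtype="bfloat16", min=-5.0, max=10.0)
    print(x_fp16.dtype, x_bf16.dtype)  # confirm the requested dtypes

    # Cast to float32 before checking the sampled range.
    x = x_bf16.astype("float32")
    print(float(x.min()), float(x.max()))  # both should fall within [-5.0, 10.0)
```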