diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
index 35769d0e6e3386786f6e8983b27a19b849051846..e39c860bc809619a73a90efbc173c5a6052eee55 100644
--- a/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
+++ b/paddle/phi/kernels/funcs/blas/blas_impl.cu.h
@@ -1316,6 +1316,74 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
   });
 }
 
+template <>
+template <>
+inline void Blas<phi::GPUContext>::GEMM(bool transA,
+                                        bool transB,
+                                        int M,
+                                        int N,
+                                        int K,
+                                        phi::dtype::bfloat16 alpha,
+                                        const phi::dtype::bfloat16 *A,
+                                        int lda,
+                                        const phi::dtype::bfloat16 *B,
+                                        int ldb,
+                                        phi::dtype::bfloat16 beta,
+                                        phi::dtype::bfloat16 *C,
+                                        int ldc) const {
+#if CUDA_VERSION >= 11000
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  cublasOperation_t cuTransA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t cuTransB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
+
+  PADDLE_ENFORCE_GE(
+      context_.GetComputeCapability(),
+      80,
+      phi::errors::InvalidArgument(
+          "cublas bf16 gemm requires GPU compute capability >= 80,"
+          "but received %d",
+          context_.GetComputeCapability()));
+
+  float h_alpha = static_cast<float>(alpha);
+  float h_beta = static_cast<float>(beta);
+
+  cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;
+  bool use_tensor_op_math = context_.tensor_core_available();
+  if (use_tensor_op_math) {
+    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+  }
+  VLOG(5) << "use_tensor_op_math: " << (use_tensor_op_math ? "True" : "False");
+
+  context_.TensorCoreCublasCallIfAvailable([&](cublasHandle_t handle) {
+    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasGemmEx(handle,
+                                                          cuTransB,
+                                                          cuTransA,
+                                                          N,
+                                                          M,
+                                                          K,
+                                                          &h_alpha,
+                                                          B,
+                                                          CUDA_R_16BF,
+                                                          ldb,
+                                                          A,
+                                                          CUDA_R_16BF,
+                                                          lda,
+                                                          &h_beta,
+                                                          C,
+                                                          CUDA_R_16BF,
+                                                          ldc,
+                                                          CUDA_R_32F,
+                                                          algo));
+  });
+#else
+  // raise error
+  PADDLE_THROW(phi::errors::Unimplemented(
+      "cublasGemmEx with bfloat16 is not supported on cuda <= 11"));
+
+#endif  // CUDA_VERSION >= 11000
+}
+
 template <>
 template <typename T>
 void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
diff --git a/paddle/phi/kernels/funcs/blas/blas_impl.hip.h b/paddle/phi/kernels/funcs/blas/blas_impl.hip.h
index 5edfe3a602c7bad978107957f02db5c562ebc545..bb02242e2db7218adc6f1d5904f26e48664e6014 100644
--- a/paddle/phi/kernels/funcs/blas/blas_impl.hip.h
+++ b/paddle/phi/kernels/funcs/blas/blas_impl.hip.h
@@ -751,7 +751,7 @@ inline void Blas<phi::GPUContext>::GEMM(CBLAS_TRANSPOSE transA,
       context_.GetComputeCapability(),
       80,
       phi::errors::InvalidArgument(
-          "rocblas fp16 gemm requires GPU compute capability >= 80,"
+          "rocblas bf16 gemm requires GPU compute capability >= 80,"
           "but received %d",
           context_.GetComputeCapability()));
 
@@ -982,6 +982,70 @@ inline void Blas<phi::GPUContext>::GEMM(bool transA,
   });
 }
 
+template <>
+template <>
+inline void Blas<phi::GPUContext>::GEMM(bool transA,
+                                        bool transB,
+                                        int M,
+                                        int N,
+                                        int K,
+                                        phi::dtype::bfloat16 alpha,
+                                        const phi::dtype::bfloat16 *A,
+                                        int lda,
+                                        const phi::dtype::bfloat16 *B,
+                                        int ldb,
+                                        phi::dtype::bfloat16 beta,
+                                        phi::dtype::bfloat16 *C,
+                                        int ldc) const {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  rocblas_operation cuTransA = (transA == CblasNoTrans)
+                                   ? rocblas_operation_none
+                                   : rocblas_operation_transpose;
+  rocblas_operation cuTransB = (transB == CblasNoTrans)
+                                   ? rocblas_operation_none
+                                   : rocblas_operation_transpose;
+  PADDLE_ENFORCE_GE(
+      context_.GetComputeCapability(),
+      80,
+      phi::errors::InvalidArgument(
+          "rocblas bf16 gemm requires GPU compute capability >= 80,"
+          "but received %d",
+          context_.GetComputeCapability()));
+
+  float h_alpha = static_cast<float>(alpha);
+  float h_beta = static_cast<float>(beta);
+  rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
+
+  context_.TensorCoreCublasCallIfAvailable([&](rocblas_handle handle) {
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        phi::dynload::rocblas_gemm_ex(handle,
+                                      cuTransB,
+                                      cuTransA,
+                                      N,
+                                      M,
+                                      K,
+                                      &h_alpha,
+                                      B,
+                                      rocblas_datatype_bf16_r,
+                                      ldb,
+                                      A,
+                                      rocblas_datatype_bf16_r,
+                                      lda,
+                                      &h_beta,
+                                      C,
+                                      rocblas_datatype_bf16_r,
+                                      ldc,
+                                      C,
+                                      rocblas_datatype_bf16_r,
+                                      ldc,
+                                      rocblas_datatype_f32_r,
+                                      algo,
+                                      0,
+                                      0));
+  });
+}
+
 template <>
 template <typename T>
 void Blas<phi::GPUContext>::AXPY(int n, T alpha, const T *x, T *y) const {
diff --git a/paddle/phi/kernels/gpu/addmm_grad_kernel.cu b/paddle/phi/kernels/gpu/addmm_grad_kernel.cu
index 65978da1374e4888afe8a7b408b0bb5a70d92b66..9d915af9170f6deba0051e6a718d292f49b49feb 100644
--- a/paddle/phi/kernels/gpu/addmm_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/addmm_grad_kernel.cu
@@ -18,5 +18,11 @@ limitations under the License. */
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/addmm_grad_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    addmm_grad, GPU, ALL_LAYOUT, phi::AddmmGradKernel, float, double) {}
+PD_REGISTER_KERNEL(addmm_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::AddmmGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/addmm_kernel.cu b/paddle/phi/kernels/gpu/addmm_kernel.cu
index 7b589ce20acca5c6cf51fd16ea223ef6b0d17466..563b137040ac779c5fd796ecc079fa926e520570 100644
--- a/paddle/phi/kernels/gpu/addmm_kernel.cu
+++ b/paddle/phi/kernels/gpu/addmm_kernel.cu
@@ -18,4 +18,11 @@ limitations under the License. */
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/addmm_kernel_impl.h"
 
-PD_REGISTER_KERNEL(addmm, GPU, ALL_LAYOUT, phi::AddmmKernel, float, double) {}
+PD_REGISTER_KERNEL(addmm,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::AddmmKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h b/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
index a72fb6062ceffe7547ad5b8c64e57bc02a4e78cd..7b05b08eef5fb1dbc1bcb452dac4c749c774ada2 100644
--- a/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/addmm_grad_kernel_impl.h
@@ -18,13 +18,34 @@ limitations under the License. */
*/ #include "glog/logging.h" +#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/kernels/addmm_grad_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" +#include "paddle/phi/kernels/funcs/for_range.h" namespace phi { +template +struct CopyOrScaleFunctor { + CopyOrScaleFunctor(const float scale, const T* x, T* output, int64_t numel) + : scale_(scale), x_(x), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + using MPType = typename phi::dtype::MPTypeTrait::Type; + const MPType mp_scale = static_cast(scale_); + const MPType mp_x = static_cast(x_[idx]); + output_[idx] = static_cast(mp_scale * mp_x); + } + + private: + const float scale_; + const T* x_; + T* output_; + int64_t numel_; +}; + template ::Type; + bool is_float16_or_bfloat16 = false; + if (std::is_same::value || + std::is_same::value) { + is_float16_or_bfloat16 = true; + } + auto in_dims = input.dims(); if (input.dims().size() == 1) { in_dims = {1, input.dims()[0]}; @@ -65,6 +93,7 @@ void AddmmGradKernel(const Context& dev_ctx, } auto blas = funcs::GetBlas(dev_ctx); + auto mt_blas = funcs::GetBlas(dev_ctx); if (input_grad) { dev_ctx.template Alloc(input_grad); total_elems = in_dims[0] * in_dims[1]; @@ -78,19 +107,60 @@ void AddmmGradKernel(const Context& dev_ctx, Array2(input_grad->dims()[0], input_grad->dims()[1]); if (row_compress && col_compress) { - eigen_dinput.device(place) = - eigen_dout.sum().eval().reshape(eigen_dinput_shape); + if (!is_float16_or_bfloat16) { + eigen_dinput.device(place) = + eigen_dout.sum().eval().reshape(eigen_dinput_shape); + } else { + eigen_dinput.device(place) = eigen_dout.template cast() + .sum() + .eval() + .reshape(eigen_dinput_shape) + .template cast(); + } } else if (row_compress) { - eigen_dinput.device(place) = - eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape); + if (!is_float16_or_bfloat16) { + eigen_dinput.device(place) = + eigen_dout.sum(Array1(0)).eval().reshape(eigen_dinput_shape); + } else { + eigen_dinput.device(place) = eigen_dout.template cast() + .sum(Array1(0)) + .eval() + .reshape(eigen_dinput_shape) + .template cast(); + } } else if (col_compress) { - eigen_dinput.device(place) = - eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape); + if (!is_float16_or_bfloat16) { + eigen_dinput.device(place) = + eigen_dout.sum(Array1(1)).eval().reshape(eigen_dinput_shape); + } else { + eigen_dinput.device(place) = eigen_dout.template cast() + .sum(Array1(1)) + .eval() + .reshape(eigen_dinput_shape) + .template cast(); + } } else { - blas.VCOPY(total_elems, out_grad.data(), input_grad->data()); + // The VCOPY does not support the float16, bfloat16 + if (!is_float16_or_bfloat16) { + mt_blas.VCOPY( + total_elems, out_grad.data(), input_grad->data()); + } else { + phi::funcs::ForRange for_range(dev_ctx, total_elems); + CopyOrScaleFunctor functor( + 1, out_grad.data(), input_grad->data(), total_elems); + for_range(functor); + } } - blas.SCAL(total_elems, beta, input_grad->data()); + // The SCAL does not support the float16, bfloat16 + if (!is_float16_or_bfloat16) { + mt_blas.SCAL(total_elems, beta, input_grad->data()); + } else { + phi::funcs::ForRange for_range(dev_ctx, total_elems); + CopyOrScaleFunctor functor( + beta, input_grad->data(), input_grad->data(), total_elems); + for_range(functor); + } if (input.dims().size() == 1) { input_grad->Resize(input.dims()); @@ -101,14 +171,28 @@ void 
     total_elems = x.dims()[0] * x.dims()[1];
     // x_grad = out_grad * y'. x_grad: M x K, out_grad : M x N, y : K x N
     blas.MatMul(out_grad, false, y, true, x_grad);
-    blas.SCAL(total_elems, alpha, x_grad->data<T>());
+    if (!is_float16_or_bfloat16) {
+      mt_blas.SCAL(total_elems, alpha, x_grad->data<T>());
+    } else {
+      phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
+      CopyOrScaleFunctor<T> functor(
+          alpha, x_grad->data<T>(), x_grad->data<T>(), total_elems);
+      for_range(functor);
+    }
   }
   if (y_grad) {
     dev_ctx.template Alloc<T>(y_grad);
     total_elems = x.dims()[1] * y.dims()[1];
     // y_grad = x' * out_grad. y_grad K x N, out_grad : M x N, x : M x K
     blas.MatMul(x, true, out_grad, false, y_grad);
-    blas.SCAL(total_elems, alpha, y_grad->data<T>());
+    if (!is_float16_or_bfloat16) {
+      mt_blas.SCAL(total_elems, alpha, y_grad->data<T>());
+    } else {
+      phi::funcs::ForRange<Context> for_range(dev_ctx, total_elems);
+      CopyOrScaleFunctor<T> functor(
+          alpha, y_grad->data<T>(), y_grad->data<T>(), total_elems);
+      for_range(functor);
+    }
   }
 }
diff --git a/paddle/phi/kernels/impl/addmm_kernel_impl.h b/paddle/phi/kernels/impl/addmm_kernel_impl.h
index c86cea80e47e8b4f24d72ac2e80ff3cc9dce2a49..957e02de6af62f732ff7b772059150e0ad78af4a 100644
--- a/paddle/phi/kernels/impl/addmm_kernel_impl.h
+++ b/paddle/phi/kernels/impl/addmm_kernel_impl.h
@@ -112,17 +112,19 @@ void AddmmKernel(const Context& dev_ctx,
   funcs::EigenBroadcast<std::decay_t<decltype(place)>, T, 2>::Eval(
       place, eigen_out, eigen_input, bcast_dims);
 
+  T t_alpha = static_cast<T>(alpha);
+  T t_beta = static_cast<T>(beta);
   blas.GEMM(false,
             false,
             x_dims[0],
             y_dims[1],
             x_dims[1],
-            alpha,
+            t_alpha,
             x.data<T>(),
             x_dims[1],
             y.data<T>(),
             y_dims[1],
-            beta,
+            t_beta,
             out->data<T>(),
             y_dims[1]);
 }
diff --git a/python/paddle/fluid/tests/unittests/test_addmm_op.py b/python/paddle/fluid/tests/unittests/test_addmm_op.py
index 3041841cdf84f165cb6479c3525a48568ff117f8..66a86961e885d8553f2722b52af124946a8de750 100644
--- a/python/paddle/fluid/tests/unittests/test_addmm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_addmm_op.py
@@ -15,11 +15,11 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle import fluid
-from paddle.fluid import Program, program_guard
+from paddle.fluid import Program, core, program_guard
 
 
 class TestAddMMOp(OpTest):
@@ -27,7 +27,6 @@
     def setUp(self):
         self.op_type = "addmm"
         self.python_api = paddle.addmm
-        self.dtype = np.float64
         self.init_dtype_type()
         self.inputs = {
             'Input': np.random.random((100, 1)).astype(self.dtype),
@@ -40,7 +39,7 @@
         }
 
     def init_dtype_type(self):
-        pass
+        self.dtype = np.float64
 
     def test_check_output(self):
         self.check_output()
@@ -58,6 +57,62 @@
         self.check_grad(['Input'], 'Out', no_grad_set=None)
 
 
+class TestAddMMFP16Op(TestAddMMOp):
+    def init_dtype_type(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        self.check_output(atol=1e-2)
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support bfloat16",
+)
+class TestAddMMBF16Op(OpTest):
+    def setUp(self):
+        self.op_type = "addmm"
+        self.python_api = paddle.addmm
+        self.init_dtype_type()
+        self.inputs = {
+            'Input': np.random.random((100, 1)).astype(self.np_dtype),
+            'X': np.random.random((100, 10)).astype(self.np_dtype),
+            'Y': np.random.random((10, 20)).astype(self.np_dtype),
+        }
+        self.outputs = {
+            'Out': self.inputs['Input']
+            + np.dot(self.inputs['X'], self.inputs['Y'])
+        }
+
+        self.inputs['Input'] = convert_float_to_uint16(self.inputs['Input'])
+        self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
+        self.inputs['Y'] = convert_float_to_uint16(self.inputs['Y'])
+        self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
+        self.place = core.CUDAPlace(0)
+
+    def init_dtype_type(self):
+        self.dtype = np.uint16
+        self.np_dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+    def test_check_grad_normal(self):
+        self.check_grad_with_place(self.place, ['Input', 'X', 'Y'], 'Out')
+
+    def test_check_grad_x(self):
+        self.check_grad_with_place(self.place, ['X'], 'Out', no_grad_set=None)
+
+    def test_check_grad_y(self):
+        self.check_grad_with_place(self.place, ['Y'], 'Out', no_grad_set=None)
+
+    def test_check_grad_input(self):
+        self.check_grad_with_place(
+            self.place, ['Input'], 'Out', no_grad_set=None
+        )
+
+
 class TestAddMMOpError(unittest.TestCase):
     # test error
     def test_errors(self):
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 37519996a2acd30347863b2114193a4508164f1f..4a257822e5b0dd00159f2265cd69851cf3d6fd00 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -1959,10 +1959,14 @@ def addmm(input, x, y, beta=1.0, alpha=1.0, name=None):
 
         helper = LayerHelper("addmm", **locals())
         check_variable_and_dtype(
-            input, 'Input', ['float32', 'float64'], 'addmm'
+            input, 'Input', ['float16', 'float32', 'float64', 'uint16'], 'addmm'
+        )
+        check_variable_and_dtype(
+            x, 'X', ['float16', 'float32', 'float64', 'uint16'], 'addmm'
+        )
+        check_variable_and_dtype(
+            y, 'Y', ['float16', 'float32', 'float64', 'uint16'], 'addmm'
         )
-        check_variable_and_dtype(x, 'X', ['float32', 'float64'], 'addmm')
-        check_variable_and_dtype(y, 'Y', ['float32', 'float64'], 'addmm')
 
         out = helper.create_variable_for_type_inference(dtype=x.dtype)
         helper.append_op(
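For reference, a minimal usage sketch of the dtype support added above (an illustration, not part of the diff; it assumes a CUDA build of Paddle with a visible GPU, and the bfloat16 path additionally requires compute capability >= 80, as enforced by the GEMM check):

import paddle

# The new float16/bfloat16 addmm kernels are registered for GPU only.
paddle.set_device('gpu')

# Shapes mirror the unit test: Input broadcasts against x @ y.
input = paddle.rand([100, 1]).astype('float16')
x = paddle.rand([100, 10]).astype('float16')
y = paddle.rand([10, 20]).astype('float16')

# out = beta * input + alpha * (x @ y), dispatched to the low-precision GEMM path.
out = paddle.addmm(input, x, y, beta=1.0, alpha=1.0)
print(out.dtype)  # paddle.float16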