From 6fc74bbaf614cf8501a812b7044191df8f21117d Mon Sep 17 00:00:00 2001
From: ShenLiang
Date: Fri, 25 Sep 2020 15:18:35 +0800
Subject: [PATCH] add fp16 for matmul (#27523)

* add fp16 for matmul
---
 paddle/fluid/operators/math/blas_impl.cu.h    | 29 ++++++
 paddle/fluid/operators/matmul_v2_op.cu        | 10 +-
 paddle/fluid/operators/matmul_v2_op.h         | 55 ++++++-----
 .../tests/unittests/test_matmul_v2_op.py      | 99 ++++++++++++++-----
 python/paddle/tensor/linalg.py                |  4 +-
 5 files changed, 142 insertions(+), 55 deletions(-)

diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h
index a0464cf70e2..aeafe22235c 100644
--- a/paddle/fluid/operators/math/blas_impl.cu.h
+++ b/paddle/fluid/operators/math/blas_impl.cu.h
@@ -420,6 +420,22 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
   });
 }
 
+template <>
+template <>
+inline void Blas<platform::CUDADeviceContext>::GEMV(
+    bool trans_a, int M, int N, platform::float16 alpha,
+    const platform::float16 *A, const platform::float16 *B,
+    platform::float16 beta, platform::float16 *C) const {
+  // Because cublas doesn't support half gemv, we use cublasHgemm to achieve it.
+  if (trans_a) {
+    this->template GEMM<platform::float16>(CblasNoTrans, CblasNoTrans, 1, N, M,
+                                           alpha, B, A, beta, C);
+  } else {
+    this->template GEMM<platform::float16>(CblasNoTrans, CblasNoTrans, M, 1, N,
+                                           alpha, A, B, beta, C);
+  }
+}
+
 template <>
 template <typename T>
 void Blas<platform::CUDADeviceContext>::BatchedGEMM(
@@ -479,6 +495,19 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
   }
 }
 
+template <>
+template <>
+inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
+    CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
+    platform::float16 alpha, const platform::float16 **A,
+    const platform::float16 **B, platform::float16 beta, platform::float16 **C,
+    int batchCount) const {
+  for (int k = 0; k < batchCount; ++k) {
+    this->template GEMM<platform::float16>(transA, transB, M, N, K, alpha, A[k],
+                                           B[k], beta, C[k]);
+  }
+}
+
 template <>
 template <typename T>
 void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
diff --git a/paddle/fluid/operators/matmul_v2_op.cu b/paddle/fluid/operators/matmul_v2_op.cu
index 64ec65a2341..91958513ddb 100644
--- a/paddle/fluid/operators/matmul_v2_op.cu
+++ b/paddle/fluid/operators/matmul_v2_op.cu
@@ -17,10 +17,12 @@ limitations under the License. */
 namespace ops = paddle::operators;
 namespace plf = paddle::platform;
 
-REGISTER_OP_CUDA_KERNEL(matmul_v2,
-                        ops::MatMulV2Kernel<plf::CUDADeviceContext, float>,
-                        ops::MatMulV2Kernel<plf::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    matmul_v2, ops::MatMulV2Kernel<plf::CUDADeviceContext, float>,
+    ops::MatMulV2Kernel<plf::CUDADeviceContext, double>,
+    ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::float16>);
 
 REGISTER_OP_CUDA_KERNEL(
     matmul_v2_grad, ops::MatMulV2GradKernel<plf::CUDADeviceContext, float>,
-    ops::MatMulV2GradKernel<plf::CUDADeviceContext, double>);
+    ops::MatMulV2GradKernel<plf::CUDADeviceContext, double>,
+    ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::float16>);
diff --git a/paddle/fluid/operators/matmul_v2_op.h b/paddle/fluid/operators/matmul_v2_op.h
index 8cd4fa12be4..ee485bd1711 100644
--- a/paddle/fluid/operators/matmul_v2_op.h
+++ b/paddle/fluid/operators/matmul_v2_op.h
@@ -163,17 +163,20 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
     if (trans_y) {
       const int M = Y->numel() / N;
       VLOG(3) << "MatMul's case 2";
-      blas.GEMV(false, M, N, 1., y_data, x_data, 0., Out->data<T>());
+      blas.GEMV(false, M, N, static_cast<T>(1), y_data, x_data,
+                static_cast<T>(0), Out->data<T>());
     } else {
       const int M = y_dims[y_ndim - 1];
       const int batch_size = Y->numel() / (M * N);
       if (batch_size == 1) {
         VLOG(3) << "MatMul's case 3";
-        blas.GEMV(true, N, M, 1., y_data, x_data, 0., Out->data<T>());
+        blas.GEMV(true, N, M, static_cast<T>(1), y_data, x_data,
+                  static_cast<T>(0), Out->data<T>());
       } else {
         VLOG(3) << "MatMul's case 4";
-        blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, 1.0f, y_data,
-                         x_data, 0, Out->data<T>(), batch_size, M * N, 0);
+        blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast<T>(1),
+                         y_data, x_data, static_cast<T>(0), Out->data<T>(),
+                         batch_size, M * N, 0);
       }
     }
     return;
@@ -205,16 +208,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
       const int batch_size = X->numel() / (M * N);
       if (batch_size == 1) {
         VLOG(3) << "MatMul's case 5";
-        blas.GEMV(true, N, M, 1.0f, x_data, y_data, 0.0f, Out->data<T>());
+        blas.GEMV(true, N, M, static_cast<T>(1), x_data, y_data,
+                  static_cast<T>(0), Out->data<T>());
       } else {
         VLOG(3) << "MatMul's case 6";
-        blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, 1.0f, x_data,
-                         y_data, 0, Out->data<T>(), batch_size, M * N, 0);
+        blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast<T>(1),
+                         x_data, y_data, static_cast<T>(0), Out->data<T>(),
+                         batch_size, M * N, 0);
       }
     } else {
       const int M = X->numel() / N;
       VLOG(3) << "MatMul's case 7";
-      blas.GEMV(false, M, N, 1.0f, x_data, y_data, 0.0f, Out->data<T>());
+      blas.GEMV(false, M, N, static_cast<T>(1), x_data, y_data,
+                static_cast<T>(0), Out->data<T>());
     }
     return;
   }
@@ -263,37 +269,38 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
   if (x_batch_size == 1 && y_batch_size == 1) {
     VLOG(3) << "MatMul's case 8";
     blas.GEMM(trans_x ? CblasTrans : CblasNoTrans,
-              trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data,
-              y_data, 0.0f, Out->data<T>());
+              trans_y ? CblasTrans : CblasNoTrans, M, N, K, static_cast<T>(1),
+              x_data, y_data, static_cast<T>(0), Out->data<T>());
   } else if (x_batch_size == 1) {
     if (M == 1 && trans_y) {
       VLOG(3) << "MatMul's case 9";
-      blas.GEMV(false, y_batch_size * N, K, 1.0f, y_data, x_data, 0.0f,
-                Out->data<T>());
+      blas.GEMV(false, y_batch_size * N, K, static_cast<T>(1), y_data, x_data,
+                static_cast<T>(0), Out->data<T>());
     } else {
       VLOG(3) << "MatMul's case 10";
       blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
-                       trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f,
-                       x_data, y_data, 0, Out->data<T>(), out_batch_size, 0,
-                       K * N);
+                       trans_y ? CblasTrans : CblasNoTrans, M, N, K,
+                       static_cast<T>(1), x_data, y_data, static_cast<T>(0),
+                       Out->data<T>(), out_batch_size, 0, K * N);
     }
   } else if (y_batch_size == 1) {
     if (!trans_x) {
      VLOG(3) << "MatMul's case 11";
      blas.GEMM(CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans,
-                x_batch_size * M, N, K, 1.0f, x_data, y_data, 0.0f,
-                Out->data<T>());
+                x_batch_size * M, N, K, static_cast<T>(1), x_data, y_data,
+                static_cast<T>(0), Out->data<T>());
     } else {
       VLOG(3) << "MatMul's case 12";
       blas.BatchedGEMM(CblasTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K,
-                       1.0f, x_data, y_data, 0, Out->data<T>(), out_batch_size,
-                       M * K, 0);
+                       static_cast<T>(1), x_data, y_data, static_cast<T>(0),
+                       Out->data<T>(), out_batch_size, M * K, 0);
     }
   } else if (!is_broadcast_dims) {
     VLOG(3) << "MatMul's case 13";
     blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
-                     trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data,
-                     y_data, 0, Out->data<T>(), out_batch_size, M * K, K * N);
+                     trans_y ? CblasTrans : CblasNoTrans, M, N, K,
+                     static_cast<T>(1), x_data, y_data, static_cast<T>(0),
+                     Out->data<T>(), out_batch_size, M * K, K * N);
   } else {
     // in the case, can't use stridedgemm
     std::vector<const T*> x_ptr(out_batch_size);
@@ -314,9 +321,9 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
     }
     VLOG(3) << "MatMul's case 14";
     blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
-                     trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f,
-                     x_ptr.data(), y_ptr.data(), 0.0f, out_ptr.data(),
-                     out_batch_size);
+                     trans_y ? CblasTrans : CblasNoTrans, M, N, K,
+                     static_cast<T>(1), x_ptr.data(), y_ptr.data(),
+                     static_cast<T>(0), out_ptr.data(), out_batch_size);
   }
 }
diff --git a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
index 884139a23d5..640771df23b 100644
--- a/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
+++ b/python/paddle/fluid/tests/unittests/test_matmul_v2_op.py
@@ -65,15 +65,21 @@ class TestMatMulV2Op(OpTest):
         self.y_shape = (100, )
         self.trans_x = False
         self.trans_y = False
+
+    def init_kernel_type(self):
         self.dtype = "float64"
 
     def setUp(self):
+        self.init_kernel_type()
         self.config()
         self.op_type = "matmul_v2"
         x = np.random.random(self.x_shape).astype(self.dtype)
         y = np.random.random(self.y_shape).astype(self.dtype)
+        # -0.1 ~ 0.1
+        x = -0.1 + 0.2 * x
+        y = -0.1 + 0.2 * y
         result = reference_matmul(x, y, self.trans_x, self.trans_y)
-
+        result = result.astype(self.dtype)
         self.inputs = {
             'X': x,
             'Y': y,
@@ -98,7 +104,6 @@ class TestMatMuklOp2(TestMatMulV2Op):
         self.y_shape = (1, 3, 2, 100)
         self.trans_x = False
         self.trans_y = True
-        self.dtype = "float64"
 
 
 class TestMatMuklOp3(TestMatMulV2Op):
@@ -111,7 +116,6 @@ class TestMatMuklOp3(TestMatMulV2Op):
         self.y_shape = (1, 1, 100, 2)
         self.trans_x = False
         self.trans_y = False
-        self.dtype = "float64"
 
 
 class TestMatMuklOp4(TestMatMulV2Op):
@@ -124,7 +128,6 @@ class TestMatMuklOp4(TestMatMulV2Op):
         self.y_shape = (1, 2, 100, 2)
         self.trans_x = False
         self.trans_y = False
-        self.dtype = "float64"
 
 
 class TestMatMuklOp5(TestMatMulV2Op):
@@ -133,11 +136,10 @@ class TestMatMuklOp5(TestMatMulV2Op):
     """
 
     def config(self):
-        self.x_shape = (1, 1, 100, 2)
+        self.x_shape = (1, 1, 100, 1)
         self.y_shape = (100, )
         self.trans_x = True
         self.trans_y = False
-        self.dtype = "float64"
 
 
 class TestMatMuklOp6(TestMatMulV2Op):
@@ -150,7 +152,6 @@ class TestMatMuklOp6(TestMatMulV2Op):
         self.y_shape = (100, )
         self.trans_x = True
         self.trans_y = False
-        self.dtype = "float64"
 
 
 class TestMatMuklOp7(TestMatMulV2Op):
@@ -163,7 +164,6 @@ class TestMatMuklOp7(TestMatMulV2Op):
         self.y_shape = (100, )
         self.trans_x = False
         self.trans_y = False
-        self.dtype = "float64"
 
 
 class TestMatMuklOp8(TestMatMulV2Op):
@@ -176,7 +176,6 @@ class TestMatMuklOp8(TestMatMulV2Op):
         self.y_shape = (1, 1, 100, 2)
         self.trans_x = False
         self.trans_y = False
-        self.dtype = "float64"
 
 
 class TestMatMuklOp9(TestMatMulV2Op):
@@ -189,7 +188,6 @@ class TestMatMuklOp9(TestMatMulV2Op):
         self.y_shape = (2, 1, 2, 100)
         self.trans_x = False
         self.trans_y = True
-        self.dtype = "float64"
 
 
 class TestMatMuklOp10(TestMatMulV2Op):
@@ -198,11 +196,10 @@ class TestMatMuklOp10(TestMatMulV2Op):
     """
 
     def config(self):
-        self.x_shape = (1, 1, 2, 100)
-        self.y_shape = (1, 2, 100, 2)
+        self.x_shape = (1, 1, 25, 4)
+        self.y_shape = (1, 2, 4, 25)
        self.trans_x = False
        self.trans_y = False
-        self.dtype = "float64"
 
 
 class TestMatMuklOp11(TestMatMulV2Op):
@@ -215,7 +212,6 @@ class TestMatMuklOp11(TestMatMulV2Op):
         self.y_shape = (1, 1, 100, 2)
         self.trans_x = False
         self.trans_y = False
-        self.dtype = "float64"
 
 
 class TestMatMuklOp12(TestMatMulV2Op):
@@ -224,11 +220,10 @@ class TestMatMuklOp12(TestMatMulV2Op):
     """
 
     def config(self):
-        self.x_shape = (2, 1, 100, 2)
-        self.y_shape = (1, 1, 100, 2)
+        self.x_shape = (2, 1, 4, 25)
+        self.y_shape = (1, 1, 4, 25)
         self.trans_x = True
         self.trans_y = False
-        self.dtype = "float64"
 
 
 class TestMatMuklOp13(TestMatMulV2Op):
@@ -237,11 +232,10 @@ class TestMatMuklOp13(TestMatMulV2Op):
     """
 
     def config(self):
-        self.x_shape = (2, 2, 100, 2)
-        self.y_shape = (2, 2, 100, 2)
+        self.x_shape = (2, 2, 2, 50)
+        self.y_shape = (2, 2, 2, 50)
         self.trans_x = True
         self.trans_y = False
-        self.dtype = "float64"
 
 
 class TestMatMuklOp14(TestMatMulV2Op):
@@ -254,7 +248,6 @@ class TestMatMuklOp14(TestMatMulV2Op):
         self.y_shape = (1, 2, 2, 100, 2)
         self.trans_x = True
         self.trans_y = False
-        self.dtype = "float64"
 
 
 class TestMatMuklOp15(TestMatMulV2Op):
@@ -267,7 +260,6 @@ class TestMatMuklOp15(TestMatMulV2Op):
         self.y_shape = (1, 2, 2, 100, 1)
         self.trans_x = False
         self.trans_y = False
-        self.dtype = "float64"
 
 
 class TestMatMuklOp16(TestMatMulV2Op):
@@ -277,10 +269,9 @@ class TestMatMuklOp16(TestMatMulV2Op):
 
     def config(self):
         self.x_shape = (100)
-        self.y_shape = (1, 2, 2, 100, 1)
+        self.y_shape = (1, 2, 2, 100, 2)
         self.trans_x = False
         self.trans_y = False
-        self.dtype = "float64"
 
 
 class TestMatMuklOp17(TestMatMulV2Op):
@@ -293,7 +284,54 @@ def config(self):
         self.y_shape = (100)
         self.trans_x = False
         self.trans_y = False
-        self.dtype = "float64"
+
+
+#--------------------test matmul fp16--------------------
+
+
+def create_test_fp16_class(parent, atol=0.001, max_relative_error=1.0):
+    @unittest.skipIf(not core.is_compiled_with_cuda(),
+                     "core is not compiled with CUDA")
+    class TestMatMulOpFp16Case(parent):
+        def init_kernel_type(self):
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            if core.is_compiled_with_cuda():
+                place = core.CUDAPlace(0)
+                if core.is_float16_supported(place):
+                    self.check_output_with_place(place, atol=atol)
+
+        def test_check_grad(self):
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                self.check_grad_with_place(
+                    place, ['X', 'Y'],
+                    'Out',
+                    max_relative_error=max_relative_error)
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
+    TestMatMulOpFp16Case.__name__ = cls_name
+    globals()[cls_name] = TestMatMulOpFp16Case
+
+
+create_test_fp16_class(TestMatMulV2Op)
+create_test_fp16_class(TestMatMuklOp2)
+create_test_fp16_class(TestMatMuklOp3)
+create_test_fp16_class(TestMatMuklOp4)
+create_test_fp16_class(TestMatMuklOp5)
+create_test_fp16_class(TestMatMuklOp6)
+create_test_fp16_class(TestMatMuklOp7)
+create_test_fp16_class(TestMatMuklOp8)
+create_test_fp16_class(TestMatMuklOp9)
+create_test_fp16_class(TestMatMuklOp10)
+create_test_fp16_class(TestMatMuklOp11)
+create_test_fp16_class(TestMatMuklOp12)
+create_test_fp16_class(TestMatMuklOp13)
+create_test_fp16_class(TestMatMuklOp14)
+create_test_fp16_class(TestMatMuklOp15)
+create_test_fp16_class(TestMatMuklOp16)
+create_test_fp16_class(TestMatMuklOp17)
 
 
 class TestMatMulV2API(unittest.TestCase):
@@ -331,6 +369,17 @@ class TestMatMulV2API(unittest.TestCase):
             y = paddle.to_tensor(input_y)
             result = paddle.matmul(x, y)
 
+    def test_dygraph_fp16(self):
+        if core.is_compiled_with_cuda():
+            place = core.CUDAPlace(0)
+            if core.is_float16_supported(place):
+                with fluid.dygraph.guard(place):
+                    input_x = np.random.random([4, 3]).astype("float16")
+                    input_y = np.random.random([3, 4]).astype("float16")
+                    x = paddle.to_tensor(input_x)
+                    y = paddle.to_tensor(input_y)
+                    result = paddle.matmul(x, y)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index f27cfba487d..26624d3b5ff 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -156,8 +156,8 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
     def __check_input(x, y):
         var_names = {'x': x, 'y': y}
         for name, val in var_names.items():
-            check_variable_and_dtype(val, name, ['float32', 'float64'],
-                                     'matmul')
+            check_variable_and_dtype(
+                val, name, ['float16', 'float32', 'float64'], 'matmul')
 
     __check_input(x, y)
-- 
GitLab
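
A minimal usage sketch of the float16 path this patch enables, based on the test_dygraph_fp16 case added above. It assumes a CUDA build of Paddle with a float16-capable GPU; the guard calls mirror those used in the new test.

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import core

# Guard the same way as the new test: require CUDA and fp16 support.
if core.is_compiled_with_cuda():
    place = core.CUDAPlace(0)
    if core.is_float16_supported(place):
        with fluid.dygraph.guard(place):
            x = paddle.to_tensor(np.random.random([4, 3]).astype("float16"))
            y = paddle.to_tensor(np.random.random([3, 4]).astype("float16"))
            # Dispatches to the newly registered float16 matmul_v2 CUDA kernel.
            out = paddle.matmul(x, y)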