Unverified commit 6fc74bba authored by ShenLiang, committed by GitHub

add fp16 for matmul (#27523)

* add fp16 for matmul
Parent fab4e6d0
@@ -420,6 +420,22 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
});
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::GEMV(
bool trans_a, int M, int N, platform::float16 alpha,
const platform::float16 *A, const platform::float16 *B,
platform::float16 beta, platform::float16 *C) const {
// Because cuBLAS does not provide a half-precision GEMV, emulate it with
// cublasHgemm (a GEMM with a single row or column).
if (trans_a) {
this->template GEMM<platform::float16>(CblasNoTrans, CblasNoTrans, 1, N, M,
alpha, B, A, beta, C);
} else {
this->template GEMM<platform::float16>(CblasNoTrans, CblasNoTrans, M, 1, N,
alpha, A, B, beta, C);
}
}
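For reference, a minimal NumPy sketch of the shape bookkeeping above (NumPy stands in for cuBLAS; names and shapes are illustrative only): the half-precision GEMV is expressed as a GEMM whose vector operand is treated as a one-row or one-column matrix.

```python
import numpy as np

# Sketch of the GEMV-as-GEMM fallback: treat the vector as a degenerate matrix.
rng = np.random.default_rng(0)
M, N = 5, 3
A = rng.standard_normal((M, N)).astype(np.float16)
x = rng.standard_normal(N).astype(np.float16)  # vector for the no-trans case
y = rng.standard_normal(M).astype(np.float16)  # vector for the trans case

# trans_a == false: GEMV(A, x) == GEMM of (M x N) by (N x 1)
no_trans = A @ x.reshape(N, 1)
assert np.allclose(no_trans.ravel(), A @ x, rtol=1e-2, atol=1e-2)

# trans_a == true: GEMV(A^T, y) == GEMM of (1 x M) by (M x N)
trans = y.reshape(1, M) @ A
assert np.allclose(trans.ravel(), A.T @ y, rtol=1e-2, atol=1e-2)
```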
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::BatchedGEMM(
@@ -479,6 +495,19 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
}
}
template <>
template <>
inline void Blas<platform::CUDADeviceContext>::BatchedGEMM(
CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
platform::float16 alpha, const platform::float16 **A,
const platform::float16 **B, platform::float16 beta, platform::float16 **C,
int batchCount) const {
for (int k = 0; k < batchCount; ++k) {
this->template GEMM<platform::float16>(transA, transB, M, N, K, alpha, A[k],
B[k], beta, C[k]);
}
}
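The float16 pointer-array overload above simply loops over the batch and issues one GEMM per slice. A NumPy sketch of the equivalence (shapes are illustrative):

```python
import numpy as np

# One GEMM per batch entry gives the same result as a single batched matmul.
rng = np.random.default_rng(1)
batch, M, K, N = 4, 2, 3, 5
A = rng.standard_normal((batch, M, K)).astype(np.float16)
B = rng.standard_normal((batch, K, N)).astype(np.float16)

out_loop = np.stack([A[k] @ B[k] for k in range(batch)])  # the per-batch loop
out_batched = np.matmul(A, B)                             # one batched call
assert np.allclose(out_loop, out_batched, rtol=1e-2, atol=1e-2)
```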
template <>
template <typename T>
void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo,
@@ -17,10 +17,12 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plf = paddle::platform;
REGISTER_OP_CUDA_KERNEL(matmul_v2,
ops::MatMulV2Kernel<plf::CUDADeviceContext, float>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
matmul_v2, ops::MatMulV2Kernel<plf::CUDADeviceContext, float>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, double>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::float16>);
REGISTER_OP_CUDA_KERNEL(
matmul_v2_grad, ops::MatMulV2GradKernel<plf::CUDADeviceContext, float>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, double>);
ops::MatMulV2GradKernel<plf::CUDADeviceContext, double>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::float16>);
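With the float16 kernels registered, matmul_v2 accepts half-precision inputs directly. A usage sketch mirroring the new dygraph test added below (requires a CUDA build and a device with float16 support):

```python
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import core

# Usage sketch only: run a float16 matmul on the GPU via the new kernel.
if core.is_compiled_with_cuda():
    place = core.CUDAPlace(0)
    if core.is_float16_supported(place):
        with fluid.dygraph.guard(place):
            x = paddle.to_tensor(np.random.random([4, 3]).astype("float16"))
            y = paddle.to_tensor(np.random.random([3, 4]).astype("float16"))
            out = paddle.matmul(x, y)  # dispatches to the fp16 CUDA kernel
```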
@@ -163,17 +163,20 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
if (trans_y) {
const int M = Y->numel() / N;
VLOG(3) << "MatMul's case 2";
blas.GEMV(false, M, N, 1., y_data, x_data, 0., Out->data<T>());
blas.GEMV(false, M, N, static_cast<T>(1), y_data, x_data,
static_cast<T>(0), Out->data<T>());
} else {
const int M = y_dims[y_ndim - 1];
const int batch_size = Y->numel() / (M * N);
if (batch_size == 1) {
VLOG(3) << "MatMul's case 3";
blas.GEMV(true, N, M, 1., y_data, x_data, 0., Out->data<T>());
blas.GEMV(true, N, M, static_cast<T>(1), y_data, x_data,
static_cast<T>(0), Out->data<T>());
} else {
VLOG(3) << "MatMul's case 4";
blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, 1.0f, y_data,
x_data, 0, Out->data<T>(), batch_size, M * N, 0);
blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast<T>(1),
y_data, x_data, static_cast<T>(0), Out->data<T>(),
batch_size, M * N, 0);
}
}
return;
@@ -205,16 +208,19 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
const int batch_size = X->numel() / (M * N);
if (batch_size == 1) {
VLOG(3) << "MatMul's case 5";
blas.GEMV(true, N, M, 1.0f, x_data, y_data, 0.0f, Out->data<T>());
blas.GEMV(true, N, M, static_cast<T>(1), x_data, y_data,
static_cast<T>(0), Out->data<T>());
} else {
VLOG(3) << "MatMul's case 6";
blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, 1.0f, x_data,
y_data, 0, Out->data<T>(), batch_size, M * N, 0);
blas.BatchedGEMM(CblasTrans, CblasNoTrans, M, 1, N, static_cast<T>(1),
x_data, y_data, static_cast<T>(0), Out->data<T>(),
batch_size, M * N, 0);
}
} else {
const int M = X->numel() / N;
VLOG(3) << "MatMul's case 7";
blas.GEMV(false, M, N, 1.0f, x_data, y_data, 0.0f, Out->data<T>());
blas.GEMV(false, M, N, static_cast<T>(1), x_data, y_data,
static_cast<T>(0), Out->data<T>());
}
return;
}
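A NumPy sketch of the lowering in case 7 above (shapes are illustrative): when Y is a 1-D vector and X is not transposed, X's leading dimensions are flattened so that a single GEMV of shape (numel(X) / N) x N covers the whole computation.

```python
import numpy as np

# Case 7: flatten X's leading dims and do one matrix-vector product.
rng = np.random.default_rng(2)
N = 6
X = rng.standard_normal((2, 3, 4, N)).astype(np.float16)
y = rng.standard_normal(N).astype(np.float16)

flat = X.reshape(-1, N) @ y            # the single GEMV over the flattened view
ref = np.matmul(X, y).reshape(-1)      # reference batched matrix-vector product
assert np.allclose(flat, ref, rtol=1e-2, atol=1e-2)
```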
@@ -263,37 +269,38 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
if (x_batch_size == 1 && y_batch_size == 1) {
VLOG(3) << "MatMul's case 8";
blas.GEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data,
y_data, 0.0f, Out->data<T>());
trans_y ? CblasTrans : CblasNoTrans, M, N, K, static_cast<T>(1),
x_data, y_data, static_cast<T>(0), Out->data<T>());
} else if (x_batch_size == 1) {
if (M == 1 && trans_y) {
VLOG(3) << "MatMul's case 9";
blas.GEMV(false, y_batch_size * N, K, 1.0f, y_data, x_data, 0.0f,
Out->data<T>());
blas.GEMV(false, y_batch_size * N, K, static_cast<T>(1), y_data, x_data,
static_cast<T>(0), Out->data<T>());
} else {
VLOG(3) << "MatMul's case 10";
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f,
x_data, y_data, 0, Out->data<T>(), out_batch_size, 0,
K * N);
trans_y ? CblasTrans : CblasNoTrans, M, N, K,
static_cast<T>(1), x_data, y_data, static_cast<T>(0),
Out->data<T>(), out_batch_size, 0, K * N);
}
} else if (y_batch_size == 1) {
if (!trans_x) {
VLOG(3) << "MatMul's case 11";
blas.GEMM(CblasNoTrans, trans_y ? CblasTrans : CblasNoTrans,
x_batch_size * M, N, K, 1.0f, x_data, y_data, 0.0f,
Out->data<T>());
x_batch_size * M, N, K, static_cast<T>(1), x_data, y_data,
static_cast<T>(0), Out->data<T>());
} else {
VLOG(3) << "MatMul's case 12";
blas.BatchedGEMM(CblasTrans, trans_y ? CblasTrans : CblasNoTrans, M, N, K,
1.0f, x_data, y_data, 0, Out->data<T>(), out_batch_size,
M * K, 0);
static_cast<T>(1), x_data, y_data, static_cast<T>(0),
Out->data<T>(), out_batch_size, M * K, 0);
}
} else if (!is_broadcast_dims) {
VLOG(3) << "MatMul's case 13";
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f, x_data,
y_data, 0, Out->data<T>(), out_batch_size, M * K, K * N);
trans_y ? CblasTrans : CblasNoTrans, M, N, K,
static_cast<T>(1), x_data, y_data, static_cast<T>(0),
Out->data<T>(), out_batch_size, M * K, K * N);
} else {
// In this case strided batched GEMM cannot be used, so fall back to the
// pointer-array BatchedGEMM.
std::vector<const T*> x_ptr(out_batch_size);
@@ -314,9 +321,9 @@ void MatMulFunction(const Tensor* X, const Tensor* Y,
}
VLOG(3) << "MatMul's case 14";
blas.BatchedGEMM(trans_x ? CblasTrans : CblasNoTrans,
trans_y ? CblasTrans : CblasNoTrans, M, N, K, 1.0f,
x_ptr.data(), y_ptr.data(), 0.0f, out_ptr.data(),
out_batch_size);
trans_y ? CblasTrans : CblasNoTrans, M, N, K,
static_cast<T>(1), x_ptr.data(), y_ptr.data(),
static_cast<T>(0), out_ptr.data(), out_batch_size);
}
}
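A NumPy sketch of the broadcasting trick in cases 10 and 12 above (shapes are illustrative): the operand without a batch dimension is reused for every batch by giving it a stride of 0 in the strided BatchedGEMM call, while the batched operand advances by K * N (or M * K) elements per step.

```python
import numpy as np

# Case 10: X has no batch dimension, so strideA = 0 reuses it for every batch
# while Y advances by K * N elements per batch entry.
rng = np.random.default_rng(3)
batch, M, K, N = 3, 2, 4, 5
X = rng.standard_normal((M, K)).astype(np.float16)         # shared operand
Y = rng.standard_normal((batch, K, N)).astype(np.float16)  # batched operand

out = np.matmul(X, Y)                             # broadcasts X over the batch
ref = np.stack([X @ Y[b] for b in range(batch)])  # what the zero stride yields
assert np.allclose(out, ref, rtol=1e-2, atol=1e-2)
```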
@@ -65,15 +65,21 @@ class TestMatMulV2Op(OpTest):
self.y_shape = (100, )
self.trans_x = False
self.trans_y = False
def init_kernel_type(self):
self.dtype = "float64"
def setUp(self):
self.init_kernel_type()
self.config()
self.op_type = "matmul_v2"
x = np.random.random(self.x_shape).astype(self.dtype)
y = np.random.random(self.y_shape).astype(self.dtype)
# scale x and y into the range [-0.1, 0.1]
x = -0.1 + 0.2 * x
y = -0.1 + 0.2 * y
result = reference_matmul(x, y, self.trans_x, self.trans_y)
result = result.astype(self.dtype)
self.inputs = {
'X': x,
'Y': y,
@@ -98,7 +104,6 @@ class TestMatMuklOp2(TestMatMulV2Op):
self.y_shape = (1, 3, 2, 100)
self.trans_x = False
self.trans_y = True
self.dtype = "float64"
class TestMatMuklOp3(TestMatMulV2Op):
@@ -111,7 +116,6 @@ class TestMatMuklOp3(TestMatMulV2Op):
self.y_shape = (1, 1, 100, 2)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMuklOp4(TestMatMulV2Op):
@@ -124,7 +128,6 @@ class TestMatMuklOp4(TestMatMulV2Op):
self.y_shape = (1, 2, 100, 2)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMuklOp5(TestMatMulV2Op):
@@ -133,11 +136,10 @@ class TestMatMuklOp5(TestMatMulV2Op):
"""
def config(self):
self.x_shape = (1, 1, 100, 2)
self.x_shape = (1, 1, 100, 1)
self.y_shape = (100, )
self.trans_x = True
self.trans_y = False
self.dtype = "float64"
class TestMatMuklOp6(TestMatMulV2Op):
@@ -150,7 +152,6 @@ class TestMatMuklOp6(TestMatMulV2Op):
self.y_shape = (100, )
self.trans_x = True
self.trans_y = False
self.dtype = "float64"
class TestMatMuklOp7(TestMatMulV2Op):
@@ -163,7 +164,6 @@ class TestMatMuklOp7(TestMatMulV2Op):
self.y_shape = (100, )
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMuklOp8(TestMatMulV2Op):
@@ -176,7 +176,6 @@ class TestMatMuklOp8(TestMatMulV2Op):
self.y_shape = (1, 1, 100, 2)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMuklOp9(TestMatMulV2Op):
@@ -189,7 +188,6 @@ class TestMatMuklOp9(TestMatMulV2Op):
self.y_shape = (2, 1, 2, 100)
self.trans_x = False
self.trans_y = True
self.dtype = "float64"
class TestMatMuklOp10(TestMatMulV2Op):
@@ -198,11 +196,10 @@ class TestMatMuklOp10(TestMatMulV2Op):
"""
def config(self):
self.x_shape = (1, 1, 2, 100)
self.y_shape = (1, 2, 100, 2)
self.x_shape = (1, 1, 25, 4)
self.y_shape = (1, 2, 4, 25)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMuklOp11(TestMatMulV2Op):
@@ -215,7 +212,6 @@ class TestMatMuklOp11(TestMatMulV2Op):
self.y_shape = (1, 1, 100, 2)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMuklOp12(TestMatMulV2Op):
@@ -224,11 +220,10 @@ class TestMatMuklOp12(TestMatMulV2Op):
"""
def config(self):
self.x_shape = (2, 1, 100, 2)
self.y_shape = (1, 1, 100, 2)
self.x_shape = (2, 1, 4, 25)
self.y_shape = (1, 1, 4, 25)
self.trans_x = True
self.trans_y = False
self.dtype = "float64"
class TestMatMuklOp13(TestMatMulV2Op):
@@ -237,11 +232,10 @@ class TestMatMuklOp13(TestMatMulV2Op):
"""
def config(self):
self.x_shape = (2, 2, 100, 2)
self.y_shape = (2, 2, 100, 2)
self.x_shape = (2, 2, 2, 50)
self.y_shape = (2, 2, 2, 50)
self.trans_x = True
self.trans_y = False
self.dtype = "float64"
class TestMatMuklOp14(TestMatMulV2Op):
@@ -254,7 +248,6 @@ class TestMatMuklOp14(TestMatMulV2Op):
self.y_shape = (1, 2, 2, 100, 2)
self.trans_x = True
self.trans_y = False
self.dtype = "float64"
class TestMatMuklOp15(TestMatMulV2Op):
@@ -267,7 +260,6 @@ class TestMatMuklOp15(TestMatMulV2Op):
self.y_shape = (1, 2, 2, 100, 1)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMuklOp16(TestMatMulV2Op):
@@ -277,10 +269,9 @@ class TestMatMuklOp16(TestMatMulV2Op):
def config(self):
self.x_shape = (100)
self.y_shape = (1, 2, 2, 100, 1)
self.y_shape = (1, 2, 2, 100, 2)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
class TestMatMuklOp17(TestMatMulV2Op):
@@ -293,7 +284,54 @@ class TestMatMuklOp17(TestMatMulV2Op):
self.y_shape = (100)
self.trans_x = False
self.trans_y = False
self.dtype = "float64"
#--------------------test matmul fp16--------------------
def create_test_fp16_class(parent, atol=0.001, max_relative_error=1.0):
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestMatMulOpFp16Case(parent):
def init_kernel_type(self):
self.dtype = np.float16
def test_check_output(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_output_with_place(place, atol=atol)
def test_check_grad(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
self.check_grad_with_place(
place, ['X', 'Y'],
'Out',
max_relative_error=max_relative_error)
cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
TestMatMulOpFp16Case.__name__ = cls_name
globals()[cls_name] = TestMatMulOpFp16Case
create_test_fp16_class(TestMatMulV2Op)
create_test_fp16_class(TestMatMuklOp2)
create_test_fp16_class(TestMatMuklOp3)
create_test_fp16_class(TestMatMuklOp4)
create_test_fp16_class(TestMatMuklOp5)
create_test_fp16_class(TestMatMuklOp6)
create_test_fp16_class(TestMatMuklOp7)
create_test_fp16_class(TestMatMuklOp8)
create_test_fp16_class(TestMatMuklOp9)
create_test_fp16_class(TestMatMuklOp10)
create_test_fp16_class(TestMatMuklOp11)
create_test_fp16_class(TestMatMuklOp12)
create_test_fp16_class(TestMatMuklOp13)
create_test_fp16_class(TestMatMuklOp14)
create_test_fp16_class(TestMatMuklOp15)
create_test_fp16_class(TestMatMuklOp16)
create_test_fp16_class(TestMatMuklOp17)
class TestMatMulV2API(unittest.TestCase):
@@ -331,6 +369,17 @@ class TestMatMulV2API(unittest.TestCase):
y = paddle.to_tensor(input_y)
result = paddle.matmul(x, y)
def test_dygraph_fp16(self):
if core.is_compiled_with_cuda():
place = core.CUDAPlace(0)
if core.is_float16_supported(place):
with fluid.dygraph.guard(place):
input_x = np.random.random([4, 3]).astype("float16")
input_y = np.random.random([3, 4]).astype("float16")
x = paddle.to_tensor(input_x)
y = paddle.to_tensor(input_y)
result = paddle.matmul(x, y)
if __name__ == "__main__":
unittest.main()
@@ -156,8 +156,8 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(val, name, ['float32', 'float64'],
'matmul')
check_variable_and_dtype(
val, name, ['float16', 'float32', 'float64'], 'matmul')
__check_input(x, y)
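With 'float16' added to the whitelist, __check_input no longer rejects half-precision inputs at graph-build time. A minimal static-graph sketch (names and shapes are illustrative; executing the program still needs a GPU with float16 support):

```python
import paddle
import paddle.fluid as fluid

# Sketch: float16 inputs now pass matmul's dtype check when building the graph.
paddle.enable_static()
with fluid.program_guard(fluid.Program(), fluid.Program()):
    x = fluid.data(name='x', shape=[4, 3], dtype='float16')
    y = fluid.data(name='y', shape=[3, 4], dtype='float16')
    out = paddle.matmul(x, y)  # previously raised a dtype error for float16
```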