From 7634a18a5fbf2b96b174e73934415b6c44273a59 Mon Sep 17 00:00:00 2001
From: houj04 <35131887+houj04@users.noreply.github.com>
Date: Sun, 23 Apr 2023 11:14:40 +0800
Subject: [PATCH] [XPU] fc use int_with_ll_t (#53183)

* [XPU] fc use int_with_ll_t

* fix test_unbind_op_xpu
---
 paddle/phi/kernels/xpu/bmm_grad_kernel.cc | 2 ++
 paddle/phi/kernels/xpu/bmm_kernel.cc      | 2 ++
 paddle/phi/kernels/xpu/xpu_api_wrapper.h  | 9 +++++++--
 test/xpu/test_unbind_op_xpu.py            | 9 +++++++--
 4 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/paddle/phi/kernels/xpu/bmm_grad_kernel.cc b/paddle/phi/kernels/xpu/bmm_grad_kernel.cc
index 8cf8c2cbb31..5f4a0d9a99d 100644
--- a/paddle/phi/kernels/xpu/bmm_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/bmm_grad_kernel.cc
@@ -33,6 +33,8 @@ void MatMul(const Context& dev_ctx,
     MatMulXPUFunction<XPUType, int32_t>(a, b, out, trans_a, trans_b, xpu_ctx);
   } else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
     MatMulXPUFunction<XPUType, float>(a, b, out, trans_a, trans_b, xpu_ctx);
+  } else if (fccal_type == XPUFCCalcType::FC_INT_WITH_LL) {
+    MatMulXPUFunction<XPUType, int_with_ll_t>(a, b, out, trans_a, trans_b, xpu_ctx);
   } else {
     MatMulXPUFunction<XPUType, int16_t>(a, b, out, trans_a, trans_b, xpu_ctx);
   }
diff --git a/paddle/phi/kernels/xpu/bmm_kernel.cc b/paddle/phi/kernels/xpu/bmm_kernel.cc
index 49571d87783..b68a5fc3c00 100644
--- a/paddle/phi/kernels/xpu/bmm_kernel.cc
+++ b/paddle/phi/kernels/xpu/bmm_kernel.cc
@@ -68,6 +68,8 @@ void BmmKernel(const Context& dev_ctx,
     MatMulXPUFunction<XPUType, int32_t>(x, y, out, trans_x, trans_y, xpu_ctx);
   } else if (fccal_type == XPUFCCalcType::FC_FLOAT) {
     MatMulXPUFunction<XPUType, float>(x, y, out, trans_x, trans_y, xpu_ctx);
+  } else if (fccal_type == XPUFCCalcType::FC_INT_WITH_LL) {
+    MatMulXPUFunction<XPUType, int_with_ll_t>(x, y, out, trans_x, trans_y, xpu_ctx);
   } else {
     MatMulXPUFunction<XPUType, int16_t>(x, y, out, trans_x, trans_y, xpu_ctx);
   }
diff --git a/paddle/phi/kernels/xpu/xpu_api_wrapper.h b/paddle/phi/kernels/xpu/xpu_api_wrapper.h
index b978cb8a2aa..5bbe1163552 100644
--- a/paddle/phi/kernels/xpu/xpu_api_wrapper.h
+++ b/paddle/phi/kernels/xpu/xpu_api_wrapper.h
@@ -30,6 +30,7 @@ enum XPUFCCalcType {
   FC_INT16 = 0,
   FC_INT32,
   FC_FLOAT,
+  FC_INT_WITH_LL,
 };
 
 template <typename T>
@@ -41,6 +42,8 @@ XPUFCCalcType FCCalcType() {
     return XPUFCCalcType::FC_INT32;
   } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
     return XPUFCCalcType::FC_FLOAT;
+  } else if (std::getenv("XPU_PADDLE_FC_INT_WITH_LL") != nullptr) {
+    return XPUFCCalcType::FC_INT_WITH_LL;
   }
   return XPUFCCalcType::FC_INT16;
 }
@@ -387,15 +390,17 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
   using XPUType = typename XPUTypeTrait<T>::Type;
   int fccal_type = FCCalcType<XPUType>();
 
-  decltype(&xpu_fc_wrapper<XPUType, int16_t>) fc_api_list[3] = {
+  decltype(&xpu_fc_wrapper<XPUType, int16_t>) fc_api_list[4] = {
       &xpu_fc_wrapper<XPUType, int16_t>,
       &xpu_fc_wrapper<XPUType, int32_t>,
       &xpu_fc_wrapper<XPUType, float>,
+      &xpu_fc_wrapper<XPUType, int_with_ll_t>,
   };
-  decltype(&xpu_fc_batch_wrapper<XPUType, int16_t>) fc_batch_api_list[3] = {
+  decltype(&xpu_fc_batch_wrapper<XPUType, int16_t>) fc_batch_api_list[4] = {
       &xpu_fc_batch_wrapper<XPUType, int16_t>,
       &xpu_fc_batch_wrapper<XPUType, int32_t>,
       &xpu_fc_batch_wrapper<XPUType, float>,
+      &xpu_fc_batch_wrapper<XPUType, int_with_ll_t>,
   };
 
   auto fc_api = fc_api_list[fccal_type];
diff --git a/test/xpu/test_unbind_op_xpu.py b/test/xpu/test_unbind_op_xpu.py
index dc8ea7ae6bc..87c1d7d25c6 100644
--- a/test/xpu/test_unbind_op_xpu.py
+++ b/test/xpu/test_unbind_op_xpu.py
@@ -41,7 +41,7 @@ class XPUTestUnbindOP(XPUOpTestWrapper):
             x_1 = paddle.static.data(shape=[2, 3], dtype=self.dtype, name='x_1')
             [out_0, out_1] = tensor.unbind(input=x_1, axis=0)
             input_1 = np.random.random([2, 3]).astype(self.dtype)
-            axis = paddle.static.data(shape=[1], dtype='int32', name='axis')
+            axis = paddle.static.data(shape=[], dtype='int32', name='axis')
 
             exe = fluid.Executor(place=self.place)
             [res_1, res_2] = exe.run(
@@ -80,7 +80,7 @@ class XPUTestUnbindOP(XPUOpTestWrapper):
             x_1 = paddle.static.data(shape=[2, 3], dtype=self.dtype, name='x_1')
             [out_0, out_1] = paddle.unbind(input=x_1, axis=0)
             input_1 = np.random.random([2, 3]).astype(self.dtype)
-            axis = paddle.static.data(shape=[1], dtype='int32', name='axis')
+            axis = paddle.static.data(shape=[], dtype='int32', name='axis')
 
             exe = fluid.Executor(place=self.place)
             [res_1, res_2] = exe.run(
@@ -196,6 +196,11 @@ class XPUTestUnbindOP(XPUOpTestWrapper):
 
             self.assertRaises(TypeError, test_table_Variable)
 
+            def test_invalid_axis():
+                tensor.unbind(input=x, axis=2)
+
+            self.assertRaises(ValueError, test_invalid_axis)
+
 
 support_types = get_xpu_op_support_types('unbind')
 for stype in support_types:
--
GitLab
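
Note: the sketch below is not part of the patch; it is a minimal standalone
C++ illustration of the calc-type selection that this change extends, for
readers who want to see the dispatch in isolation. The env var name
XPU_PADDLE_FC_INT32 used for the first branch is an assumption, since that
line sits outside the visible hunk context.

    // sketch.cc -- illustration only; XPU_PADDLE_FC_INT32 is an assumed name
    // for the first branch, which is not visible in the hunk above.
    #include <cstdlib>
    #include <iostream>

    enum XPUFCCalcType {
      FC_INT16 = 0,    // default calc type
      FC_INT32,
      FC_FLOAT,
      FC_INT_WITH_LL,  // new in this patch: the int_with_ll_t path
    };

    XPUFCCalcType FCCalcType() {
      // Earlier env vars win; with none set, FC_INT16 is the fallback.
      if (std::getenv("XPU_PADDLE_FC_INT32") != nullptr) {  // assumed name
        return FC_INT32;
      } else if (std::getenv("XPU_PADDLE_FC_LOCAL_INT16") != nullptr) {
        return FC_FLOAT;
      } else if (std::getenv("XPU_PADDLE_FC_INT_WITH_LL") != nullptr) {
        return FC_INT_WITH_LL;  // branch added by this patch
      }
      return FC_INT16;
    }

    int main() {
      // The enum value doubles as the index into fc_api_list and
      // fc_batch_api_list, which is why both arrays grew from 3 to 4 entries.
      std::cout << "fccal_type = " << FCCalcType() << std::endl;
      return 0;
    }

Running it with XPU_PADDLE_FC_INT_WITH_LL set prints fccal_type = 3, the
index of the new int_with_ll_t wrappers in both tables.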