From 6916215e3718bd494f596790354788cd517cfa0e Mon Sep 17 00:00:00 2001
From: zhangyikun02 <48021248+zhangyk0314@users.noreply.github.com>
Date: Fri, 4 Nov 2022 10:32:54 +0800
Subject: [PATCH] matmul_v2 support new case and fix masked_select bug for
 xpu, test=kunlun (#47370)

---
 .../phi/kernels/xpu/masked_select_kernel.cc   | 18 ++++++-----
 paddle/phi/kernels/xpu/matmul_grad_kernel.cc  | 18 +++++++++++
 paddle/phi/kernels/xpu/xpu_api_wrapper.h      | 32 ++++++++++++++++---
 .../unittests/xpu/test_matmul_v2_op_xpu.py    | 24 ++++++++++++++
 4 files changed, 79 insertions(+), 13 deletions(-)

diff --git a/paddle/phi/kernels/xpu/masked_select_kernel.cc b/paddle/phi/kernels/xpu/masked_select_kernel.cc
index 43b8d2cba2..0f142e852a 100644
--- a/paddle/phi/kernels/xpu/masked_select_kernel.cc
+++ b/paddle/phi/kernels/xpu/masked_select_kernel.cc
@@ -62,14 +62,16 @@ void MaskedSelectKernel(const Context& dev_ctx,
   auto input_shape = vectorize(input_dim);
   auto mask_shape = vectorize(mask_dim);
 
-  PADDLE_ENFORCE_XDNN_SUCCESS(xpu::masked_select(dev_ctx.x_context(),
-                                                 input_data,
-                                                 mask_data,
-                                                 out_data,
-                                                 input_shape,
-                                                 mask_shape,
-                                                 out_size_cpu),
-                              "masked_select");
+  if (out_size_cpu > 0) {
+    PADDLE_ENFORCE_XDNN_SUCCESS(xpu::masked_select(dev_ctx.x_context(),
+                                                   input_data,
+                                                   mask_data,
+                                                   out_data,
+                                                   input_shape,
+                                                   mask_shape,
+                                                   out_size_cpu),
+                                "masked_select");
+  }
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/xpu/matmul_grad_kernel.cc b/paddle/phi/kernels/xpu/matmul_grad_kernel.cc
index dfdac5a552..07f93dc2d6 100644
--- a/paddle/phi/kernels/xpu/matmul_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/matmul_grad_kernel.cc
@@ -56,6 +56,15 @@ void MatmulGradKernel(const Context& dev_ctx,
                     : reinterpret_cast<XPUType*>(dx->data<T>());
   XPUType* c_2 = (dy == NULL) ? reinterpret_cast<XPUType*>(NULL)
                               : reinterpret_cast<XPUType*>(dy->data<T>());
+
+  if (info_forward.is_x_need_broadcast) {
+    XPUType* new_c_1 = nullptr;
+    new_c_1 = RAII_GUARD.alloc_l3_or_gm<XPUType>(
+        info_forward.bs * info_forward.m * info_forward.k);
+    PADDLE_ENFORCE_XDNN_NOT_NULL(new_c_1);
+    c_1 = new_c_1;
+  }
+
   XpuFcInfo info_dx;
   XpuFcInfo info_dy;
   std::tuple<XpuFcInfo,
@@ ... @@ void MatmulGradKernel(const Context& dev_ctx,
   if (dx) {
     MatMulXPUFunction<XPUType>(xpu_ctx, a_1, b_1, c_1, info_dx, 1.0f);
+    if (info_forward.is_x_need_broadcast) {
+      int r = xpu::reduce_sum<XPUType>(
+          xpu_ctx,
+          c_1,
+          reinterpret_cast<XPUType*>(dx->data<T>()),
+          {info_forward.bs, info_forward.m, info_forward.k},
+          {0});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
+    }
   }
   if (dy) {
     MatMulXPUFunction<XPUType>(xpu_ctx, a_2, b_2, c_2, info_dy, 1.0f);
diff --git a/paddle/phi/kernels/xpu/xpu_api_wrapper.h b/paddle/phi/kernels/xpu/xpu_api_wrapper.h
index 0ebed7f449..8fefbd84c6 100644
--- a/paddle/phi/kernels/xpu/xpu_api_wrapper.h
+++ b/paddle/phi/kernels/xpu/xpu_api_wrapper.h
@@ -58,6 +58,7 @@ struct XpuFcInfo {
   float* max_x;
   float* max_y;
   float* max_out;
+  bool is_x_need_broadcast;
   XpuFcInfo()
       : bs(0),
         m(0),
@@ -70,7 +71,8 @@
         stride_out(0),
         max_x(nullptr),
         max_y(nullptr),
-        max_out(nullptr) {}
+        max_out(nullptr),
+        is_x_need_broadcast(false) {}
   void InitFcInfo(int bs,
                   int m,
                   int n,
@@ -145,8 +147,12 @@ static void GetFCInfo(const phi::DDim& x_dims,
                           y_dims.to_str(),
                           mat_dim_a.trans_,
                           mat_dim_b.trans_));
-    mat_dim_b.height_ *= mat_dim_b.batch_size_;
-    mat_dim_b.batch_size_ = 0;
+    if (mat_dim_a.width_ == mat_dim_b.batch_size_ * mat_dim_b.height_) {
+      mat_dim_b.height_ *= mat_dim_b.batch_size_;
+      mat_dim_b.batch_size_ = 0;
+    } else {
+      info->is_x_need_broadcast = true;
+    }
   }
 
   if (mat_dim_a.width_ == mat_dim_b.height_) {
@@ -171,7 +177,7 @@ static void GetFCInfo(const phi::DDim& x_dims,
   info->m = mat_dim_a.height_;
   info->n = mat_dim_b.width_;
   info->k = mat_dim_a.width_;
-  info->bs = mat_dim_a.batch_size_;
+  info->bs = std::max(mat_dim_a.batch_size_, mat_dim_b.batch_size_);
   info->trans_x = trans_x;
   info->trans_y = trans_y;
@@ -406,6 +412,7 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
   float* max_x = fcinfo.max_x;
   float* max_y = fcinfo.max_y;
   float* max_out = fcinfo.max_out;
+  bool is_x_need_broadcast = fcinfo.is_x_need_broadcast;
 
   if (batch_size <= 1) {
     fc_api(xpu_ctx,
@@ -428,6 +435,19 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
            nullptr,
            xpu::Activation_t::LINEAR);
   } else {
+    const XPUType* x_data = reinterpret_cast<const XPUType*>(x);
+    if (is_x_need_broadcast) {
+      XPUType* x_broadcast_data = nullptr;
+      xpu::ctx_guard RAII_GUARD(xpu_ctx);
+      x_broadcast_data = RAII_GUARD.alloc_l3_or_gm<XPUType>(batch_size * m * k);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(x_broadcast_data);
+      std::vector<int> x_shape = {1, m, k};
+      std::vector<int> new_x_shape = {batch_size, m, k};
+      int r = xpu::broadcast<XPUType>(
+          xpu_ctx, x_data, x_broadcast_data, x_shape, new_x_shape);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
+      x_data = x_broadcast_data;
+    }
     // batch matmul
     fc_batch_api(xpu_ctx,     // Context* ctx,
                  batch_size,  // int batch_size,
@@ -437,7 +457,7 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
                  n,           // int n,
                  k,           // int k,
                  alpha,       // float alpha,
-                 reinterpret_cast<const XPUType*>(x),  // const TX* x,
+                 x_data,      // const TX* x,
                  ldx,         // int stride_a,
                  reinterpret_cast<const XPUType*>(y),  // const TW* w,
                  ldy,         // int stride_b,
@@ -554,6 +574,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
                        nullptr,
                        max_dout,
                        nullptr);
+    dy_shape.is_x_need_broadcast = dout_shape.is_x_need_broadcast;
     dy_a = x, dy_b = dout_new;
   } else if (trans_y) {
     // dx = dout * y
@@ -600,6 +621,7 @@ MatmulGradFcInfo(xpu::Context* xpu_ctx,
                        nullptr,
                        max_dout,
                        nullptr);
+    dy_shape.is_x_need_broadcast = dout_shape.is_x_need_broadcast;
     dy_a = x, dy_b = dout_new;
   }
   std::tuple<XpuFcInfo,
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py
index 3e873a965f..c2a1ab4ee0 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py
@@ -294,6 +294,30 @@ class XPUTestMatmulV2Op(XPUOpTestWrapper):
             self.trans_x = False
             self.trans_y = False
 
+    class TestMatMulOp19(TestMatMulV2Op):
+        """
+        case 19 : (x.ndim <= 2) && (y.ndim >= 3),
+                  x needs to be broadcast and trans_y is false
+        """
+
+        def config(self):
+            self.x_shape = (10, 20)
+            self.y_shape = (2, 20, 4)
+            self.trans_x = False
+            self.trans_y = False
+
+    class TestMatMulOp20(TestMatMulV2Op):
+        """
+        case 20 : (x.ndim <= 2) && (y.ndim >= 3),
+                  x needs to be broadcast and trans_x is true
+        """
+
+        def config(self):
+            self.x_shape = (20, 10)
+            self.y_shape = (2, 20, 4)
+            self.trans_x = True
+            self.trans_y = False
+
 
 support_types = get_xpu_op_support_types('matmul_v2')
 for stype in support_types:
-- 
GitLab
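
Note for readers (not part of the patch itself): the new case is a 2-D x multiplied by a 3-D y, where x must be broadcast over y's batch dimension before the batched matmul, and dx must be summed back over that dimension in the backward pass; that is what the added xpu::broadcast and xpu::reduce_sum calls do. A minimal numpy sketch of the expected shapes for the TestMatMulOp19 configuration, assuming paddle.matmul on XPU follows the same broadcasting rule as numpy:

import numpy as np

x = np.random.random((10, 20)).astype(np.float32)    # (m, k), as in TestMatMulOp19
y = np.random.random((2, 20, 4)).astype(np.float32)  # (bs, k, n)

# x is implicitly treated as (1, 10, 20) and broadcast to (2, 10, 20).
out = np.matmul(x, y)
assert out.shape == (2, 10, 4)

# Backward: the batched dx contribution has shape (bs, m, k) = (2, 10, 20);
# summing it over the batch axis restores x's original shape (10, 20), which
# is the reduce_sum over axis 0 added in MatmulGradKernel.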