diff --git a/paddle/fluid/operators/matmul_op_xpu.cc b/paddle/fluid/operators/matmul_op_xpu.cc
index 8834e95758bf2f43c8ccda213b559d04c18556ce..f92cff2f6cd216493b12834bfd1744bb57e21460 100644
--- a/paddle/fluid/operators/matmul_op_xpu.cc
+++ b/paddle/fluid/operators/matmul_op_xpu.cc
@@ -159,23 +159,14 @@ static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out,
                           "XPU fc_fusion kernel return wrong value[%d %s]", r,
                           XPUAPIErrorMsg[r]));
   } else {
-    // batch matmul
-    int x_stride = mat_dim_a.stride_;
-    int y_stride = mat_dim_b.stride_;
-    int out_stride = m * n;
-    for (int i = 0; i < batch_size; ++i) {
-      const float *x_data = x->data<float>() + x_stride * i;
-      const float *y_data = y->data<float>() + y_stride * i;
-      float *out_data = data_c + out_stride * i;
-      int r = xpu::fc_fusion<float, float, float, int16_t>(
-          dev_ctx.x_context(), x_data, y_data, out_data, m, n, k,
-          mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
-          ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
-      PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
-                        platform::errors::External(
-                            "XPU fc_fusion kernel return wrong value[%d %s]", r,
-                            XPUAPIErrorMsg[r]));
-    }
+    int r = xpu::fc_batched<float, float, float, int16_t>(
+        dev_ctx.x_context(), batch_size, mat_dim_a.trans_, mat_dim_b.trans_, m,
+        n, k, alpha, x->data<float>(), mat_dim_a.stride_, y->data<float>(),
+        mat_dim_b.stride_, 0.0, data_c, m * n, nullptr, nullptr);
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU fc_batched kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
   }
 }
 
diff --git a/paddle/fluid/operators/matmul_v2_op_xpu.cc b/paddle/fluid/operators/matmul_v2_op_xpu.cc
index 765a380c6b84ff4cf0da7a36bf9e1050ff0a9b73..dbb1d7bfb0a3d9fb2d8727b48061da954728da01 100644
--- a/paddle/fluid/operators/matmul_v2_op_xpu.cc
+++ b/paddle/fluid/operators/matmul_v2_op_xpu.cc
@@ -79,22 +79,14 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out,
                           "XPU fc_fusion kernel return wrong value[%d %s]", r,
                           XPUAPIErrorMsg[r]));
   } else {
-    // batch matmul
-    int x_stride = mat_dim_a.stride_;
-    int y_stride = mat_dim_b.stride_;
-    int out_stride = m * n;
-    for (int i = 0; i < batch_size; ++i) {
-      const float* x_data = x->data<float>() + x_stride * i;
-      const float* y_data = y->data<float>() + y_stride * i;
-      float* out_data = data_c + out_stride * i;
-      int r = xpu::fc<float, float, float, int16_t>(
-          dev_ctx.x_context(), x_data, y_data, out_data, m, n, k,
-          mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr);
-      PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
-                        platform::errors::External(
-                            "XPU fc_fusion kernel return wrong value[%d %s]", r,
-                            XPUAPIErrorMsg[r]));
-    }
+    int r = xpu::fc_batched<float, float, float, int16_t>(
+        dev_ctx.x_context(), batch_size, mat_dim_a.trans_, mat_dim_b.trans_, m,
+        n, k, 1.0, x->data<float>(), mat_dim_a.stride_, y->data<float>(),
+        mat_dim_b.stride_, 0.0, data_c, m * n, nullptr, nullptr);
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU fc_batched kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
   }
 }