// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/kernels/matmul_kernel.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/xpu/xpu_api_wrapper.h" namespace phi { template void MatmulKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, bool transpose_x, bool transpose_y, DenseTensor* out) { using XPUType = typename XPUTypeTrait::Type; dev_ctx.template Alloc(out); const XPUType* x_ptr = reinterpret_cast(x.data()); const XPUType* y_ptr = reinterpret_cast(y.data()); XPUType* out_ptr = reinterpret_cast(out->data()); auto x_dims = x.dims(); auto y_dims = y.dims(); XpuFcInfo fc_info; GetFCInfo(x_dims, y_dims, transpose_x, transpose_y, &fc_info); xpu::Context* xpu_ctx = dev_ctx.x_context(); MatMulXPUFunction(xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f); } template void MatmulWithFlattenKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, int x_num_col_dims, int y_num_col_dims, DenseTensor* out) { using XPUType = typename XPUTypeTrait::Type; const DenseTensor x_matrix = x.dims().size() > 2 ? phi::ReshapeToMatrix(x, x_num_col_dims) : x; const DenseTensor y_matrix = y.dims().size() > 2 ? phi::ReshapeToMatrix(y, y_num_col_dims) : y; dev_ctx.template Alloc(out); const XPUType* x_ptr = reinterpret_cast(x_matrix.data()); const XPUType* y_ptr = reinterpret_cast(y_matrix.data()); XPUType* out_ptr = reinterpret_cast(out->data()); bool trans_a = false; bool trans_b = false; auto x_dims = x_matrix.dims(); auto y_dims = y_matrix.dims(); phi::XpuFcInfo fc_info; phi::GetFCInfo(x_dims, y_dims, trans_a, trans_b, &fc_info); xpu::Context* xpu_ctx = dev_ctx.x_context(); phi::MatMulXPUFunction( xpu_ctx, x_ptr, y_ptr, out_ptr, fc_info, 1.0f); } } // namespace phi PD_REGISTER_KERNEL( matmul, XPU, ALL_LAYOUT, phi::MatmulKernel, float, phi::dtype::float16) {} PD_REGISTER_KERNEL(matmul_with_flatten, XPU, ALL_LAYOUT, phi::MatmulWithFlattenKernel, float, phi::dtype::float16) {}