// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/kernels/triangular_solve_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/expand_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/common_shape.h" namespace phi { template void TriangularSolveKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, bool upper, bool transpose, bool unitriangular, DenseTensor* out) { // get broadcast dim std::vector x_bst_dims_vec; std::vector y_bst_dims_vec; std::tie(x_bst_dims_vec, y_bst_dims_vec) = funcs::MatrixGetBroadcastDims(x, y); int x_bst_ndim = x_bst_dims_vec.size(); int y_bst_ndim = y_bst_dims_vec.size(); // Tensor broadcast to 'out' and temp 'x_bst' ScalarArray x_bst_dims(x_bst_dims_vec); DenseTensor x_bst = phi::Empty(dev_ctx, x_bst_dims); const T* x_bst_data = x_bst.data(); ExpandKernel(dev_ctx, x, x_bst_dims, &x_bst); out->Resize(phi::make_ddim(y_bst_dims_vec)); T* out_data = dev_ctx.template Alloc(out); ScalarArray y_bst_dims(y_bst_dims_vec); ExpandKernel(dev_ctx, y, y_bst_dims, out); // Calculate use blas library int M = static_cast(y_bst_dims_vec[y_bst_ndim - 2]); int N = static_cast(y_bst_dims_vec[y_bst_ndim - 1]); int batch_size = 1; for (int i = 0; i < x_bst_ndim - 2; i++) { batch_size *= x_bst_dims_vec[i]; } auto blas = phi::funcs::GetBlas(dev_ctx); for (int i = 0; i < batch_size; i++) { blas.TRSM(CblasLeft, upper ? CblasUpper : CblasLower, transpose ? CblasTrans : CblasNoTrans, unitriangular ? CblasUnit : CblasNonUnit, M, N, T(1), x_bst_data + i * M * M, std::max(1, M), out_data + i * N * M, std::max(1, N)); } } } // namespace phi PD_REGISTER_KERNEL(triangular_solve, CPU, ALL_LAYOUT, phi::TriangularSolveKernel, float, double) {}