diff --git a/paddle/fluid/operators/triangular_solve_op.cc b/paddle/fluid/operators/triangular_solve_op.cc index 9233917b0931b98d30b736ec9b69fd68c0604d18..179f818104c9bcd9ca53420b00299984979e410e 100644 --- a/paddle/fluid/operators/triangular_solve_op.cc +++ b/paddle/fluid/operators/triangular_solve_op.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/triangular_solve_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/solve_op.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -22,58 +25,6 @@ class TriangularSolveOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "TriangularSolve"); - OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "TriangularSolve"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "TriangularSolve"); - - auto x_dims = ctx->GetInputDim("X"); - auto y_dims = ctx->GetInputDim("Y"); - - auto x_dims_n = x_dims.size(); - auto y_dims_n = y_dims.size(); - - PADDLE_ENFORCE_GE( - x_dims_n, 2, platform::errors::InvalidArgument( - "The input tensor X's dimensions of TriangularSolveOp " - "should be >= 2. But received X's " - "dimensions = %d, X's shape = [%s]", - x_dims.size(), x_dims)); - - PADDLE_ENFORCE_GE( - y_dims_n, 2, platform::errors::InvalidArgument( - "The input tensor Y's dimensions of TriangularSolveOp " - "should be >=2. But received Y's " - "dimensions = %d, Y's shape = [%s]", - y_dims.size(), y_dims)); - - PADDLE_ENFORCE_EQ(x_dims[x_dims_n - 2], x_dims[x_dims_n - 1], - platform::errors::InvalidArgument( - "The inner-most 2 dimensions of Input(X) all should " - "be square matrices " - "But received X's shape[-2] = %d and shape[-1] = %d.", - x_dims[x_dims_n - 2], x_dims[x_dims_n - 1])); - - std::vector x_dims_vec = phi::vectorize(x_dims); - std::vector y_dims_vec = phi::vectorize(y_dims); - - std::vector x_dims_vec_cut(x_dims_vec.begin(), - x_dims_vec.end() - 2); - std::vector y_dims_vec_cut(y_dims_vec.begin(), - y_dims_vec.end() - 2); - - std::vector expand_batch_portion = - get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut); - - std::vector y_broadcast_dims({expand_batch_portion}); - y_broadcast_dims.insert(y_broadcast_dims.end(), {y_dims_vec[y_dims_n - 2], - y_dims_vec[y_dims_n - 1]}); - - // dim of 'Out' is the same with 'Y' after broadcast - ctx->SetOutputDim("Out", phi::make_ddim(y_broadcast_dims)); - ctx->ShareLoD("X", /*->*/ "Out"); - } - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const { return framework::OpKernelType( @@ -168,20 +119,15 @@ class TriangularSolveOpGradMaker : public framework::SingleGradOpMaker { } // namespace paddle namespace ops = paddle::operators; + +DELCARE_INFER_SHAPE_FUNCTOR(triangular_solve, TriangularSolveInferShapeFunctor, + PT_INFER_META(phi::TriangularSolveInferMeta)); + REGISTER_OPERATOR(triangular_solve, ops::TriangularSolveOp, ops::TriangularSolveOpMaker, ops::TriangularSolveOpInferVarType, ops::TriangularSolveOpGradMaker, - ops::TriangularSolveOpGradMaker); + ops::TriangularSolveOpGradMaker, + TriangularSolveInferShapeFunctor); REGISTER_OPERATOR(triangular_solve_grad, ops::TriangularSolveGradOp); - -REGISTER_OP_CPU_KERNEL( - triangular_solve, - ops::TriangularSolveKernel, - ops::TriangularSolveKernel); - -REGISTER_OP_CPU_KERNEL( - triangular_solve_grad, - ops::TriangularSolveGradKernel, - ops::TriangularSolveGradKernel); diff --git a/paddle/fluid/operators/triangular_solve_op.cu b/paddle/fluid/operators/triangular_solve_op.cu deleted file mode 100644 index 7df98517e8418905f0f8c8ce603762967a8b5f38..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/triangular_solve_op.cu +++ /dev/null @@ -1,65 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/reduce_ops/reduce_op.h" -#include "paddle/fluid/operators/triangular_solve_op.h" - -namespace paddle { -namespace operators { - -template -class MatrixReduceSumFunctor { - public: - void operator()(const Tensor& in, Tensor* out, - const framework::ExecutionContext& ctx) { - // For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3] - // out_reduce_dim should be [0, 2] - const std::vector in_dims = phi::vectorize(in.dims()); - auto in_size = in_dims.size(); - const std::vector out_dims = phi::vectorize(out->dims()); - auto out_size = out_dims.size(); - - std::vector out_bst_dims(in_size); - - std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1); - std::copy(out_dims.data(), out_dims.data() + out_size, - out_bst_dims.data() + in_size - out_size); - - std::vector out_reduce_dims; - for (size_t idx = 0; idx <= in_size - 3; idx++) { - if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) { - out_reduce_dims.push_back(idx); - } - } - gpuStream_t stream = ctx.cuda_device_context().stream(); - TensorReduceImpl>( - ctx.cuda_device_context(), in, out, kps::IdentityFunctor(), - out_reduce_dims, stream); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - triangular_solve, - ops::TriangularSolveKernel, - ops::TriangularSolveKernel); - -REGISTER_OP_CUDA_KERNEL( - triangular_solve_grad, - ops::TriangularSolveGradKernel, - ops::TriangularSolveGradKernel); diff --git a/paddle/fluid/operators/triangular_solve_op.h b/paddle/fluid/operators/triangular_solve_op.h index 4e68add096ff28f5378b02689248c3957c1e8ae9..315847b4d800e46aea6c927f9b7055261b56e9bc 100644 --- a/paddle/fluid/operators/triangular_solve_op.h +++ b/paddle/fluid/operators/triangular_solve_op.h @@ -21,7 +21,6 @@ limitations under the License. */ #include "paddle/fluid/operators/solve_op.h" #include "paddle/fluid/operators/tril_triu_op.h" #include "paddle/phi/core/ddim.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/complex_functors.h" namespace paddle { @@ -30,10 +29,10 @@ namespace operators { using Tensor = framework::Tensor; template -static void triangular_solve(const DeviceContext& context, const Tensor& x, - const Tensor& y, Tensor* out, bool upper, +static void triangular_solve(const DeviceContext &context, const Tensor &x, + const Tensor &y, Tensor *out, bool upper, bool transpose, bool unitriangular) { - // Tensor broadcast use eigen + // Tensor broadcast use eigen library std::vector x_bst_dims_vec; std::vector y_bst_dims_vec; std::tie(x_bst_dims_vec, y_bst_dims_vec) = get_broadcast_dims(x, y); @@ -64,15 +63,15 @@ static void triangular_solve(const DeviceContext& context, const Tensor& x, template class MatrixReduceSumFunctor { public: - void operator()(const Tensor& input, Tensor* output, - const framework::ExecutionContext& ctx); + void operator()(const Tensor &input, Tensor *output, + const framework::ExecutionContext &ctx); }; template class MatrixReduceSumFunctor { public: - void operator()(const Tensor& in, Tensor* out, - const framework::ExecutionContext& ctx) { + void operator()(const Tensor &in, Tensor *out, + const framework::ExecutionContext &ctx) { // For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3] // out_reduce_dim should be [0, 2] const std::vector in_dims = phi::vectorize(in.dims()); @@ -101,129 +100,5 @@ class MatrixReduceSumFunctor { } }; -template -class TriangularSolveKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const auto* x = ctx.Input("X"); - const auto* y = ctx.Input("Y"); - auto* out = ctx.Output("Out"); - - bool upper = ctx.template Attr("upper"); - bool transpose = ctx.template Attr("transpose"); - bool unitriangular = ctx.template Attr("unitriangular"); - - const auto& dev_ctx = ctx.template device_context(); - triangular_solve(dev_ctx, *x, *y, out, upper, transpose, - unitriangular); - } -}; - -template -class TriangularSolveGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const auto* x = ctx.Input("X"); - const auto* y = ctx.Input("Y"); - const auto* out = ctx.Input("Out"); - const auto* dout = - ctx.Input(framework::GradVarName("Out")); - - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dy = ctx.Output(framework::GradVarName("Y")); - - bool upper = ctx.template Attr("upper"); - bool transpose = ctx.template Attr("transpose"); - bool unitriangular = ctx.template Attr("unitriangular"); - - auto& dev_ctx = ctx.template device_context(); - - std::vector x_bst_dims_vec; - std::vector y_bst_dims_vec; - std::tie(x_bst_dims_vec, y_bst_dims_vec) = get_broadcast_dims(*x, *y); - - Tensor dy_bst(y->type()); - if (dy) { - dy->mutable_data(y->dims(), dev_ctx.GetPlace()); - dy_bst.Resize(phi::make_ddim(y_bst_dims_vec)); - dy_bst.mutable_data(dev_ctx.GetPlace()); - - // calculate x's conjugate for complex - Tensor x_conj(x->type()); - platform::ForRange x_for_range(dev_ctx, x->numel()); - phi::funcs::ConjFunctor x_functor( - x->data(), x->numel(), - x_conj.mutable_data(x->dims(), dev_ctx.GetPlace())); - x_for_range(x_functor); - - // reuse forward to get dy_bst, and the result has been broadcated. - triangular_solve(dev_ctx, x_conj, *dout, &dy_bst, upper, - !transpose, unitriangular); - - if (dy_bst.dims() == dy->dims()) { - framework::TensorCopy(dy_bst, dev_ctx.GetPlace(), dev_ctx, dy); - } else { - MatrixReduceSumFunctor functor; - functor(dy_bst, dy, ctx); - dy->Resize(y->dims()); - } - } - - Tensor dx_bst(x->type()); - if (dx) { - dx->mutable_data(x->dims(), dev_ctx.GetPlace()); - dx_bst.Resize(phi::make_ddim(x_bst_dims_vec)); - dx_bst.mutable_data(dev_ctx.GetPlace()); - - // calculate out's conjugate for complex - Tensor out_conj(out->type()); - platform::ForRange out_for_range(dev_ctx, out->numel()); - phi::funcs::ConjFunctor out_functor( - out->data(), out->numel(), - out_conj.mutable_data(out->dims(), dev_ctx.GetPlace())); - out_for_range(out_functor); - - auto blas = phi::funcs::GetBlas(ctx); - if (transpose) { - auto mat_dim_a = - phi::funcs::CreateMatrixDescriptor(out_conj.dims(), 0, false); - auto mat_dim_b = - phi::funcs::CreateMatrixDescriptor(dy_bst.dims(), 0, true); - blas.MatMul(out_conj, mat_dim_a, dy_bst, mat_dim_b, static_cast(-1), - &dx_bst, static_cast(0)); - } else { - auto mat_dim_a = - phi::funcs::CreateMatrixDescriptor(dy_bst.dims(), 0, false); - auto mat_dim_b = - phi::funcs::CreateMatrixDescriptor(out_conj.dims(), 0, true); - blas.MatMul(dy_bst, mat_dim_a, out_conj, mat_dim_b, static_cast(-1), - &dx_bst, static_cast(0)); - } - - Tensor dx_bst_upper(x->type()); - // get upper or lower triangular - dx_bst_upper.Resize(dx_bst.dims()); - dx_bst_upper.mutable_data(dev_ctx.GetPlace()); - - const auto& dims = dx_bst.dims(); - const auto H = dims[dims.size() - 2]; - const auto W = dims[dims.size() - 1]; - platform::ForRange x_for_range(dev_ctx, dx_bst.numel()); - TrilTriuCompute tril_triu_computer(dx_bst.data(), unitriangular, - !upper, H, W, - dx_bst_upper.data()); - x_for_range(tril_triu_computer); - - if (dx_bst_upper.dims() == dx->dims()) { - framework::TensorCopy(dx_bst_upper, dev_ctx.GetPlace(), dev_ctx, dx); - } else { - MatrixReduceSumFunctor functor; - functor(dx_bst_upper, dx, ctx); - dx->Resize(x->dims()); - } - } - } -}; - } // namespace operators } // namespace paddle diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 03128e96a838f1abc30f71a096f3b4cb43071af9..c017e5864aa95f0307e01810297c22ec10b33533 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -274,6 +274,65 @@ void HuberLossInferMeta(const MetaTensor& input, out->share_lod(input); } +void TriangularSolveInferMeta(const MetaTensor& x, + const MetaTensor& y, + bool upper, + bool transpose, + bool unitriangular, + MetaTensor* out) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + + auto x_dims_n = x_dims.size(); + auto y_dims_n = y_dims.size(); + + PADDLE_ENFORCE_GE(x_dims_n, + 2, + phi::errors::InvalidArgument( + "The input tensor X's dimensions of TriangularSolveOp " + "should be >= 2. But received X's " + "dimensions = %d, X's shape = [%s]", + x_dims.size(), + x_dims)); + + PADDLE_ENFORCE_GE(y_dims_n, + 2, + phi::errors::InvalidArgument( + "The input tensor Y's dimensions of TriangularSolveOp " + "should be >=2. But received Y's " + "dimensions = %d, Y's shape = [%s]", + y_dims.size(), + y_dims)); + + PADDLE_ENFORCE_EQ(x_dims[x_dims_n - 2], + x_dims[x_dims_n - 1], + phi::errors::InvalidArgument( + "The inner-most 2 dimensions of Input(X) all should " + "be square matrices " + "But received X's shape[-2] = %d and shape[-1] = %d.", + x_dims[x_dims_n - 2], + x_dims[x_dims_n - 1])); + + std::vector x_dims_vec = phi::vectorize(x_dims); + std::vector y_dims_vec = phi::vectorize(y_dims); + + std::vector x_dims_vec_cut(x_dims_vec.begin(), x_dims_vec.end() - 2); + std::vector y_dims_vec_cut(y_dims_vec.begin(), y_dims_vec.end() - 2); + + std::vector expand_batch_portion = + funcs::MatrixGetBroadcastBatchPortion(x_dims_vec_cut, y_dims_vec_cut); + + std::vector y_broadcast_dims({expand_batch_portion}); + y_broadcast_dims.insert(y_broadcast_dims.end(), + {y_dims_vec[y_dims_n - 2], y_dims_vec[y_dims_n - 1]}); + + // dim of 'out' is the same with 'Y' after broadcast + out->set_dims(phi::make_ddim(y_broadcast_dims)); + out->set_dtype(y.dtype()); + out->set_layout(y.layout()); + out->share_lod(y); +} + void IndexSampleInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index f397c0def8a0bf1ba8fa22e535ca5137c873524e..976c17cd8d91e5649da8fc0a93dc0ea520c2435a 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -62,6 +62,13 @@ void HuberLossInferMeta(const MetaTensor& input_meta, MetaTensor* residual, MetaConfig config = MetaConfig()); +void TriangularSolveInferMeta(const MetaTensor& x, + const MetaTensor& y, + bool upper, + bool transpose, + bool unitriangular, + MetaTensor* out); + void IndexSampleInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 4ffa1826a29fa3904b959a1e8f2fd9ceb27511b4..e9108787082d071d67ea0012add837bc18592a0a 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -18,10 +18,11 @@ set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta) # NOTE: Some kernels depend on some targets that are not commonly used. # These targets are not suitable for common dependencies. # In this case, you need to manually generate them here. -set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel) +set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel) kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel) kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax) +kernel_library(triangular_solve_grad_kernel DEPS ${COMMON_KERNEL_DEPS} matrix_reduce) # auto parse and build kernel targets by cmake register_kernels(EXCLUDES ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS}) diff --git a/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc b/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..80b2015f7318ad9a8b46c77460ca70a17f801a6d --- /dev/null +++ b/paddle/phi/kernels/cpu/triangular_solve_grad_kernel.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(triangular_solve_grad, + CPU, + ALL_LAYOUT, + phi::TriangularSolveGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/triangular_solve_kernel.cc b/paddle/phi/kernels/cpu/triangular_solve_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..5aca5be12792387659b1c4db00e5d8ed98bc22dc --- /dev/null +++ b/paddle/phi/kernels/cpu/triangular_solve_kernel.cc @@ -0,0 +1,84 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/triangular_solve_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/common_shape.h" + +namespace phi { + +template +void TriangularSolveKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + bool upper, + bool transpose, + bool unitriangular, + DenseTensor* out) { + // get broadcast dim + std::vector x_bst_dims_vec; + std::vector y_bst_dims_vec; + std::tie(x_bst_dims_vec, y_bst_dims_vec) = + funcs::MatrixGetBroadcastDims(x, y); + int x_bst_ndim = x_bst_dims_vec.size(); + int y_bst_ndim = y_bst_dims_vec.size(); + + // Tensor broadcast to 'out' and temp 'x_bst' + ScalarArray x_bst_dims(x_bst_dims_vec); + DenseTensor x_bst = phi::Empty(dev_ctx, x_bst_dims); + const T* x_bst_data = x_bst.data(); + ExpandKernel(dev_ctx, x, x_bst_dims, &x_bst); + + out->Resize(phi::make_ddim(y_bst_dims_vec)); + T* out_data = dev_ctx.template Alloc(out); + ScalarArray y_bst_dims(y_bst_dims_vec); + ExpandKernel(dev_ctx, y, y_bst_dims, out); + + // Calculate use blas library + int M = static_cast(y_bst_dims_vec[y_bst_ndim - 2]); + int N = static_cast(y_bst_dims_vec[y_bst_ndim - 1]); + int batch_size = 1; + for (int i = 0; i < x_bst_ndim - 2; i++) { + batch_size *= x_bst_dims_vec[i]; + } + + auto blas = phi::funcs::GetBlas(dev_ctx); + for (int i = 0; i < batch_size; i++) { + blas.TRSM(CblasLeft, + upper ? CblasUpper : CblasLower, + transpose ? CblasTrans : CblasNoTrans, + unitriangular ? CblasUnit : CblasNonUnit, + M, + N, + T(1), + x_bst_data + i * M * M, + std::max(1, M), + out_data + i * N * M, + std::max(1, N)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(triangular_solve, + CPU, + ALL_LAYOUT, + phi::TriangularSolveKernel, + float, + double) {} diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt index 8b8697b6df12cf16ff038c67a8f0079a9dbef5b8..02cba6009c4005020c95cd401a84f5b350b9560c 100644 --- a/paddle/phi/kernels/funcs/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/CMakeLists.txt @@ -8,3 +8,4 @@ math_library(sequence2batch) math_library(gru_compute DEPS activation_functions math_function) math_library(lstm_compute DEPS activation_functions) math_library(concat_and_split_functor DEPS dense_tensor) +math_library(matrix_reduce DEPS dense_tensor) diff --git a/paddle/phi/kernels/funcs/common_shape.h b/paddle/phi/kernels/funcs/common_shape.h index dce80caab72bf3d0b64c4144874bda2835a13671..139341536debf068b82704d5e7d70a3edbe045e0 100644 --- a/paddle/phi/kernels/funcs/common_shape.h +++ b/paddle/phi/kernels/funcs/common_shape.h @@ -140,6 +140,72 @@ inline bool CheckDims(const DDim &dims_x, const DDim &dims_y) { return true; } +// Just For Matrix OP, for example: +// x's dim = [5, 3, 2, M, M] ; y's dim = [3, 1, M, N] +// out [5, 3, 2], which is batch_size of matrix +static inline std::vector MatrixGetBroadcastBatchPortion( + std::vector x, std::vector y) { + size_t size_x = x.size(); + size_t size_y = y.size(); + size_t size = std::max(size_x, size_y); + std::vector batchPortion(size); + + ptrdiff_t i = (ptrdiff_t)size - 1; + for (; i >= 0; --i) { + ptrdiff_t offset = size - i - 1; + ptrdiff_t dim_x = size_x - offset - 1; + ptrdiff_t dim_y = size_y - offset - 1; + int64_t x_size = (dim_x >= 0) ? x[dim_x] : 1; + int64_t y_size = (dim_y >= 0) ? y[dim_y] : 1; + + PADDLE_ENFORCE_EQ( + (x_size == y_size || x_size == 1 || y_size == 1), + true, + phi::errors::PreconditionNotMet( + "The size of tensor x (%d) must match the size of tensor y " + "(%d) at non-singleton dimension %d.", + x_size, + y_size, + i)); + + batchPortion[i] = x_size != 1 ? x_size : y_size; + } + return batchPortion; +} + +// Just For Matrix OP, for example: +// x's dim = [5, 3, 2, M, M] ; y's dim = [3, 1, M, N] +// out shoule be [5, 3, 2, M, M] + [5, 3, 2, M, N], and [5, 3, 2] is +// batch_size of matrix +static inline std::tuple, std::vector> +MatrixGetBroadcastDims(const DenseTensor &x, const DenseTensor &y) { + std::vector x_dims_vec = phi::vectorize(x.dims()); + std::vector y_dims_vec = phi::vectorize(y.dims()); + + std::vector::const_iterator f1 = x_dims_vec.begin(); + std::vector::const_iterator l1 = x_dims_vec.end() - 2; + std::vector x_dims_vec_cut(f1, l1); + + std::vector::const_iterator f2 = y_dims_vec.begin(); + std::vector::const_iterator l2 = y_dims_vec.end() - 2; + std::vector y_dims_vec_cut(f2, l2); + + std::vector expand_batch_portion = + MatrixGetBroadcastBatchPortion(x_dims_vec_cut, y_dims_vec_cut); + + std::vector x_expand_size({expand_batch_portion}); + x_expand_size.insert(x_expand_size.end(), + {x_dims_vec[static_cast(x_dims_vec.size()) - 2], + x_dims_vec[static_cast(x_dims_vec.size()) - 1]}); + + std::vector y_expand_size({expand_batch_portion}); + y_expand_size.insert(y_expand_size.end(), + {y_dims_vec[static_cast(y_dims_vec.size()) - 2], + y_dims_vec[static_cast(y_dims_vec.size()) - 1]}); + + return std::make_tuple(x_expand_size, y_expand_size); +} + inline DDim GetOutputDims(const DDim &s_dims, const DDim &l_dims) { if (s_dims.size() > l_dims.size()) { return GetOutputDims(l_dims, s_dims); diff --git a/paddle/phi/kernels/funcs/matrix_reduce.cc b/paddle/phi/kernels/funcs/matrix_reduce.cc new file mode 100644 index 0000000000000000000000000000000000000000..849fd7a0075a89cedeab4d87c779931f2a14f115 --- /dev/null +++ b/paddle/phi/kernels/funcs/matrix_reduce.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/funcs/matrix_reduce.h" + +#include "paddle/phi/kernels/cpu/reduce.h" +#include "paddle/phi/kernels/funcs/reduce_functor.h" + +namespace phi { +namespace funcs { + +template +class MatrixReduceSumFunctor { + public: + void operator()(const CPUContext& dev_ctx, + const DenseTensor& in, + DenseTensor* out) { + // For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3] + // out_reduce_dim should be [0, 2] + const std::vector in_dims = phi::vectorize(in.dims()); + auto in_size = in_dims.size(); + const std::vector out_dims = phi::vectorize(out->dims()); + auto out_size = out_dims.size(); + + std::vector out_bst_dims(in_size); + + std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1); + std::copy(out_dims.data(), + out_dims.data() + out_size, + out_bst_dims.data() + in_size - out_size); + out->Resize(phi::make_ddim(out_bst_dims)); + + std::vector out_reduce_dims; + for (size_t idx = 0; idx <= in_size - 3; idx++) { + if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) { + out_reduce_dims.push_back(idx); + } + } + phi::ReduceKernelImpl( + dev_ctx, in, out, out_reduce_dims, true, false); + } +}; + +template class MatrixReduceSumFunctor; +template class MatrixReduceSumFunctor; + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/matrix_reduce.cu b/paddle/phi/kernels/funcs/matrix_reduce.cu new file mode 100644 index 0000000000000000000000000000000000000000..5e288c6e9c21703471ba7b6a6014510ba845ebd8 --- /dev/null +++ b/paddle/phi/kernels/funcs/matrix_reduce.cu @@ -0,0 +1,62 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/funcs/matrix_reduce.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" + +namespace phi { +namespace funcs { + +template +class MatrixReduceSumFunctor { + public: + void operator()(const GPUContext& dev_ctx, + const DenseTensor& in, + DenseTensor* out) { + // For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3] + // out_reduce_dim should be [0, 2] + const std::vector in_dims = phi::vectorize(in.dims()); + auto in_size = in_dims.size(); + const std::vector out_dims = phi::vectorize(out->dims()); + auto out_size = out_dims.size(); + + std::vector out_bst_dims(in_size); + + std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1); + std::copy(out_dims.data(), + out_dims.data() + out_size, + out_bst_dims.data() + in_size - out_size); + out->Resize(phi::make_ddim(out_bst_dims)); + + std::vector out_reduce_dims; + for (size_t idx = 0; idx <= in_size - 3; idx++) { + if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) { + out_reduce_dims.push_back(idx); + } + } + TensorReduceImpl>( + dev_ctx, + in, + out, + kps::IdentityFunctor(), + out_reduce_dims, + dev_ctx.stream()); + } +}; + +template class MatrixReduceSumFunctor; +template class MatrixReduceSumFunctor; + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/matrix_reduce.h b/paddle/phi/kernels/funcs/matrix_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..22bddacd43d437e731ff5baf3e3c18c52dc55fd6 --- /dev/null +++ b/paddle/phi/kernels/funcs/matrix_reduce.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { +namespace funcs { + +// Use For Matrix OP, reduce_sum 'in' according to out's dim +// for example: in's dim = [5, 3, 2, M, N] ; out's dim = [3, 1, M, N] +// axis [0, 2] of DenseTensor 'in' will be reduced +template +class MatrixReduceSumFunctor { + public: + void operator()(const Context& dev_ctx, + const DenseTensor& in, + DenseTensor* out); +}; + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu b/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..f7eaa485797947ae7b6a60378737d8c955718466 --- /dev/null +++ b/paddle/phi/kernels/gpu/triangular_solve_grad_kernel.cu @@ -0,0 +1,23 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(triangular_solve_grad, + GPU, + ALL_LAYOUT, + phi::TriangularSolveGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/triangular_solve_kernel.cu b/paddle/phi/kernels/gpu/triangular_solve_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..f137d8e1c260387686cbf3d0fbadf686d9e13019 --- /dev/null +++ b/paddle/phi/kernels/gpu/triangular_solve_kernel.cu @@ -0,0 +1,132 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/triangular_solve_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/common_shape.h" + +// See Note [ Why still include the fluid headers? ] +#include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/memory/memory.h" + +namespace phi { + +template +void TriangularSolveKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + bool upper, + bool transpose, + bool unitriangular, + DenseTensor* out) { + // get broadcast dim + std::vector x_bst_dims_vec; + std::vector y_bst_dims_vec; + std::tie(x_bst_dims_vec, y_bst_dims_vec) = + funcs::MatrixGetBroadcastDims(x, y); + int x_bst_ndim = x_bst_dims_vec.size(); + int y_bst_ndim = y_bst_dims_vec.size(); + + // Tensor broadcast to 'out' and temp 'x_bst' + ScalarArray x_bst_dims(x_bst_dims_vec); + DenseTensor x_bst = phi::Empty(dev_ctx, x_bst_dims); + const T* x_bst_data = x_bst.data(); + ExpandKernel(dev_ctx, x, x_bst_dims, &x_bst); + + out->Resize(phi::make_ddim(y_bst_dims_vec)); + T* out_data = dev_ctx.template Alloc(out); + ScalarArray y_bst_dims(y_bst_dims_vec); + ExpandKernel(dev_ctx, y, y_bst_dims, out); + + // calculate use cublas library + CBLAS_UPLO uplo = upper ? CblasUpper : CblasLower; + CBLAS_TRANSPOSE transA = transpose ? CblasTrans : CblasNoTrans; + CBLAS_DIAG diag = unitriangular ? CblasUnit : CblasNonUnit; + + int M = static_cast(y_bst_dims_vec[y_bst_ndim - 2]); + int N = static_cast(y_bst_dims_vec[y_bst_ndim - 1]); + auto lda = std::max(1, M); + auto ldb = std::max(1, N); + + int batch_size = 1; + for (int i = 0; i < x_bst_ndim - 2; i++) { + batch_size *= x_bst_dims_vec[i]; + } + + auto blas = phi::funcs::GetBlas(dev_ctx); + if (batch_size <= 8 && M >= 64) { + for (auto i = 0; i < batch_size; i++) { + blas.TRSM(CblasLeft, + uplo, + transA, + diag, + M, + N, + T(1), + x_bst_data + i * M * M, + lda, + out_data + i * N * M, + ldb); + } + } else { + std::vector cpu_ptrs(batch_size * 2); + for (int i = 0; i < batch_size; ++i) { + cpu_ptrs[i] = x_bst_data + i * M * M; + cpu_ptrs[i + batch_size] = out_data + i * M * N; + } + + // Copy the addresses of A and tmp_b from host to device. + paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data = + paddle::memory::Alloc(dev_ctx, cpu_ptrs.size() * sizeof(T*)); + + paddle::memory::Copy(dev_ctx.GetPlace(), + tmp_gpu_ptrs_data->ptr(), + paddle::platform::CPUPlace(), + static_cast(cpu_ptrs.data()), + cpu_ptrs.size() * sizeof(T*), + dev_ctx.stream()); + + const T** gpu_a_ptrs = + reinterpret_cast(tmp_gpu_ptrs_data->ptr()); + T** gpu_b_ptrs = + reinterpret_cast(tmp_gpu_ptrs_data->ptr()) + batch_size; + blas.BatchedTRSM(CblasLeft, + uplo, + transA, + diag, + M, + N, + static_cast(1.0), + gpu_a_ptrs, + lda, + gpu_b_ptrs, + ldb, + batch_size); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(triangular_solve, + GPU, + ALL_LAYOUT, + phi::TriangularSolveKernel, + float, + double) {} diff --git a/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h b/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..a6868ebe6ca51c1e412249695469d6e3ec35363c --- /dev/null +++ b/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h @@ -0,0 +1,138 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/triangular_solve_grad_kernel.h" + +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/common_shape.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/matrix_reduce.h" +#include "paddle/phi/kernels/triangular_solve_kernel.h" + +// See Note [ Why still include the fluid headers? ] +#include "paddle/fluid/operators/tril_triu_op.h" + +namespace phi { + +template +void TriangularSolveGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + bool upper, + bool transpose, + bool unitriangular, + DenseTensor* dx, + DenseTensor* dy) { + std::vector x_bst_dims_vec; + std::vector y_bst_dims_vec; + std::tie(x_bst_dims_vec, y_bst_dims_vec) = + funcs::MatrixGetBroadcastDims(x, y); + + ScalarArray y_bst_dims_array(y_bst_dims_vec); + DenseTensor dy_bst = phi::Empty(dev_ctx, y_bst_dims_array); + if (dy) { + // calculate x's conjugate for complex + DenseTensor x_conj = phi::Empty(dev_ctx); + x_conj.Resize(x.dims()); + + phi::funcs::ForRange x_for_range(dev_ctx, x.numel()); + phi::funcs::ConjFunctor x_functor( + x.data(), x.numel(), dev_ctx.template Alloc(&x_conj)); + x_for_range(x_functor); + + // reuse forward to get dy_bst, and the result has been broadcated already. + TriangularSolveKernel( + dev_ctx, x_conj, dout, upper, !transpose, unitriangular, &dy_bst); + + dy->Resize(y.dims()); + dev_ctx.template Alloc(dy); + if (dy_bst.dims() == y.dims()) { + Copy(dev_ctx, dy_bst, dev_ctx.GetPlace(), false, dy); + } else { + funcs::MatrixReduceSumFunctor functor; + functor(dev_ctx, dy_bst, dy); + dy->Resize(y.dims()); + } + } + + ScalarArray x_bst_dims_array(x_bst_dims_vec); + DenseTensor dx_bst = phi::Empty(dev_ctx, x_bst_dims_array); + if (dx) { + // calculate x's conjugate for complex + DenseTensor out_conj = phi::Empty(dev_ctx); + out_conj.Resize(out.dims()); + + phi::funcs::ForRange out_for_range(dev_ctx, out.numel()); + phi::funcs::ConjFunctor out_functor( + out.data(), out.numel(), dev_ctx.template Alloc(&out_conj)); + out_for_range(out_functor); + + auto blas = phi::funcs::GetBlas(dev_ctx); + if (transpose) { + auto mat_dim_a = + phi::funcs::CreateMatrixDescriptor(out_conj.dims(), 0, false); + auto mat_dim_b = + phi::funcs::CreateMatrixDescriptor(dy_bst.dims(), 0, true); + blas.MatMul(out_conj, + mat_dim_a, + dy_bst, + mat_dim_b, + static_cast(-1), + &dx_bst, + static_cast(0)); + } else { + auto mat_dim_a = + phi::funcs::CreateMatrixDescriptor(dy_bst.dims(), 0, false); + auto mat_dim_b = + phi::funcs::CreateMatrixDescriptor(out_conj.dims(), 0, true); + blas.MatMul(dy_bst, + mat_dim_a, + out_conj, + mat_dim_b, + static_cast(-1), + &dx_bst, + static_cast(0)); + } + + // get upper or lower triangular + DenseTensor dx_bst_upper = + phi::Empty(dev_ctx, x_bst_dims_array); + + const auto& dims = dx_bst.dims(); + const auto H = dims[dims.size() - 2]; + const auto W = dims[dims.size() - 1]; + phi::funcs::ForRange x_for_range(dev_ctx, dx_bst.numel()); + paddle::operators::TrilTriuCompute tril_triu_functor( + dx_bst.data(), unitriangular, !upper, H, W, dx_bst_upper.data()); + x_for_range(tril_triu_functor); + + dx->Resize(x.dims()); + dev_ctx.template Alloc(dx); + if (dx_bst.dims() == x.dims()) { + Copy(dev_ctx, dx_bst_upper, dev_ctx.GetPlace(), false, dx); + } else { + funcs::MatrixReduceSumFunctor functor; + functor(dev_ctx, dx_bst_upper, dx); + dx->Resize(x.dims()); + } + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/triangular_solve_grad_kernel.h b/paddle/phi/kernels/triangular_solve_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..eb5a5ab461a1dcbbdec916dff57e65df5d9cfd9b --- /dev/null +++ b/paddle/phi/kernels/triangular_solve_grad_kernel.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void TriangularSolveGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + bool upper, + bool transpose, + bool unitriangular, + DenseTensor* dx, + DenseTensor* dy); + +} // namespace phi diff --git a/paddle/phi/kernels/triangular_solve_kernel.h b/paddle/phi/kernels/triangular_solve_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..833de3f8439ee843577306ff146fa01ad4225390 --- /dev/null +++ b/paddle/phi/kernels/triangular_solve_kernel.h @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void TriangularSolveKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + bool upper, + bool transpose, + bool unitriangular, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/ops/compat/triangular_solve_sig.cc b/paddle/phi/ops/compat/triangular_solve_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..c56af3e21e53e9ded6d01ad7fdb9c0fb5609ea6c --- /dev/null +++ b/paddle/phi/ops/compat/triangular_solve_sig.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature TriangularSolveGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("triangular_solve_grad", + {"X", "Y", "Out", GradVarName("Out")}, + {"upper", "transpose", "unitriangular"}, + {GradVarName("X"), GradVarName("Y")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(triangular_solve_grad, + phi::TriangularSolveGradOpArgumentMapping);