From dce87e3d668235e845918f100057c3e3c17069d5 Mon Sep 17 00:00:00 2001
From: caozhou <48191911+Caozhou1995@users.noreply.github.com>
Date: Wed, 16 Mar 2022 14:21:36 +0800
Subject: [PATCH] [Phi] Migrate multiplex, qr, tril_triu op kernels to phi
 (#40007)

* migrate multiplex op kernel
* migrate qr cpu kernel
* migrate tril_triu op kernel
* fix multiplex kernel
* add kernel sig
* fix dependencies and bugs
* fix multiplex error
* fix npu include error
* fix conflict
* fix conflict and delete tril_triu
* fix date and multiplex input
* adapt header file order
* fix header file includes
* fix conflict
* delete cholesky_solve_op.h
* delete triangular_solve_op.h
---
 paddle/fluid/operators/lu_op.h                |  30 ++---
 paddle/fluid/operators/lu_unpack_op.h         |   9 +-
 paddle/fluid/operators/multiplex_op.cc        |  14 +--
 paddle/fluid/operators/multiplex_op.cu        | 117 ------------------
 paddle/fluid/operators/multiplex_op.h         |  96 --------------
 paddle/fluid/operators/qr_op.cc               |   2 -
 paddle/fluid/operators/qr_op.h                |  79 ------------
 paddle/fluid/operators/tril_triu_op.cc        |  18 +--
 paddle/fluid/operators/tril_triu_op.cu        |  35 ------
 paddle/fluid/operators/tril_triu_op.h         | 102 ---------------
 paddle/fluid/operators/tril_triu_op_npu.cc    |   2 +-
 paddle/fluid/operators/tril_triu_op_xpu.cc    |   2 +-
 .../phi/kernels/cpu/multiplex_grad_kernel.cc  |  65 ++++++++++
 paddle/phi/kernels/cpu/multiplex_kernel.cc    |  65 ++++++++++
 paddle/phi/kernels/cpu/qr_kernel.cc           | 116 +++++++++++++++++
 .../phi/kernels/cpu/tril_triu_grad_kernel.cc  |  29 +++++
 paddle/phi/kernels/cpu/tril_triu_kernel.cc    |  29 +++++
 paddle/phi/kernels/funcs/tril_triu_compute.h  |  48 +++++++
 .../phi/kernels/gpu/multiplex_grad_kernel.cu  |  68 ++++++++++
 paddle/phi/kernels/gpu/multiplex_kernel.cu    |  70 +++++++++++
 .../phi/kernels/gpu/tril_triu_grad_kernel.cu  |  29 +++++
 paddle/phi/kernels/gpu/tril_triu_kernel.cu    |  29 +++++
 .../impl/cholesky_solve_grad_kernel_impl.h    |   7 +-
 .../impl/triangular_solve_grad_kernel_impl.h  |   7 +-
 .../kernels/impl/tril_triu_grad_kernel_impl.h |  44 +++++++
 .../phi/kernels/impl/tril_triu_kernel_impl.h  |  43 +++++++
 paddle/phi/kernels/multiplex_grad_kernel.h    |  27 ++++
 paddle/phi/kernels/multiplex_kernel.h         |  27 ++++
 paddle/phi/kernels/qr_kernel.h                |  28 +++++
 paddle/phi/kernels/tril_triu_grad_kernel.h    |  28 +++++
 paddle/phi/kernels/tril_triu_kernel.h         |  28 +++++
 paddle/phi/ops/compat/multiplex_sig.cc        |  32 +++++
 paddle/phi/ops/compat/qr_sig.cc               |  25 ++++
 paddle/phi/ops/compat/tril_triu_sig.cc        |  34 +++++
 34 files changed, 896 insertions(+), 488 deletions(-)
 delete mode 100644 paddle/fluid/operators/multiplex_op.cu
 delete mode 100644 paddle/fluid/operators/multiplex_op.h
 delete mode 100644 paddle/fluid/operators/tril_triu_op.cu
 delete mode 100644 paddle/fluid/operators/tril_triu_op.h
 create mode 100644 paddle/phi/kernels/cpu/multiplex_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/multiplex_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/qr_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/tril_triu_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/tril_triu_kernel.cc
 create mode 100644 paddle/phi/kernels/funcs/tril_triu_compute.h
 create mode 100644 paddle/phi/kernels/gpu/multiplex_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/multiplex_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/tril_triu_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/tril_triu_kernel.cu
 create mode 100644 paddle/phi/kernels/impl/tril_triu_grad_kernel_impl.h
 create mode 100644 paddle/phi/kernels/impl/tril_triu_kernel_impl.h
 create mode 100644 paddle/phi/kernels/multiplex_grad_kernel.h
 create mode 100644 paddle/phi/kernels/multiplex_kernel.h
 create mode 100644 paddle/phi/kernels/qr_kernel.h
 create mode 100644 paddle/phi/kernels/tril_triu_grad_kernel.h
 create mode 100644 paddle/phi/kernels/tril_triu_kernel.h
 create mode 100644 paddle/phi/ops/compat/multiplex_sig.cc
 create mode 100644 paddle/phi/ops/compat/qr_sig.cc
 create mode 100644 paddle/phi/ops/compat/tril_triu_sig.cc

diff --git a/paddle/fluid/operators/lu_op.h b/paddle/fluid/operators/lu_op.h
index 214b2eccae9..6e2ac4617da 100644
--- a/paddle/fluid/operators/lu_op.h
+++ b/paddle/fluid/operators/lu_op.h
@@ -18,8 +18,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/phi_utils.h"
 #include "paddle/fluid/operators/set_value_op.h"
 #include "paddle/fluid/operators/svd_helper.h"
-#include "paddle/fluid/operators/tril_triu_op.h"
 #include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
+#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
 #include "paddle/phi/kernels/math_kernel.h"
 #include "paddle/phi/kernels/triangular_solve_kernel.h"
 
@@ -404,11 +404,12 @@ void LU_Unpack(const DeviceContext& dev_ctx, const framework::Tensor* LU,
   const auto W = udims[udims.size() - 1];
   auto L_dataptr = L->mutable_data<T>(dev_ctx.GetPlace());
   platform::ForRange<DeviceContext> x_for_range(dev_ctx, LU->numel());
-  TrilTriuCompute<T> tril_computer(LU->data<T>(), -1, true, H, W, L_dataptr);
+  phi::funcs::TrilTriuCompute<T> tril_computer(LU->data<T>(), -1, true, H, W,
+                                               L_dataptr);
   x_for_range(tril_computer);
 
-  TrilTriuCompute<T> triu_computer(LU->data<T>(), 0, false, H, W,
-                                   U->mutable_data<T>(dev_ctx.GetPlace()));
+  phi::funcs::TrilTriuCompute<T> triu_computer(
+      LU->data<T>(), 0, false, H, W, U->mutable_data<T>(dev_ctx.GetPlace()));
   x_for_range(triu_computer);
 
   // set L's diagonal 1
@@ -532,15 +533,15 @@ class LUGradKernel : public framework::OpKernel<T> {
     auto phil_rank = LmHdims.size();
     auto phiu_rank = UmHdims.size();
     platform::ForRange<DeviceContext> l_for_range(dev_ctx, phi_L.numel());
-    TrilTriuCompute<T> tril_computer(phi_L.data<T>(), -1, true,
-                                     LmHdims[phil_rank - 2],
-                                     LmHdims[phil_rank - 1], phi_L.data<T>());
+    phi::funcs::TrilTriuCompute<T> tril_computer(
+        phi_L.data<T>(), -1, true, LmHdims[phil_rank - 2],
+        LmHdims[phil_rank - 1], phi_L.data<T>());
     l_for_range(tril_computer);
     platform::ForRange<DeviceContext> u_for_range(dev_ctx, phi_U.numel());
-    TrilTriuCompute<T> triu_computer(phi_U.data<T>(), 0, false,
-                                     UmHdims[phiu_rank - 2],
-                                     UmHdims[phiu_rank - 1], phi_U.data<T>());
+    phi::funcs::TrilTriuCompute<T> triu_computer(
+        phi_U.data<T>(), 0, false, UmHdims[phiu_rank - 2],
+        UmHdims[phiu_rank - 1], phi_U.data<T>());
     u_for_range(triu_computer);
 
     Tensor_Add<DeviceContext, T>(dev_ctx, phi_L, phi_U, &phi);
@@ -591,8 +592,9 @@ class LUGradKernel : public framework::OpKernel<T> {
       const auto W = phidims[phidims.size() - 1];
       platform::ForRange<DeviceContext> x_for_range(dev_ctx,
                                                     phi_complement.numel());
-      TrilTriuCompute<T> tril_computer(phi_complement.data<T>(), -1, true, H,
-                                       W, phi_complement_l.data<T>());
+      phi::funcs::TrilTriuCompute<T> tril_computer(
+          phi_complement.data<T>(), -1, true, H, W,
+          phi_complement_l.data<T>());
       x_for_range(tril_computer);
 
       Tensor_Sub<DeviceContext, T>(dev_ctx, phi, phi_complement_l, &phi);
@@ -664,8 +666,8 @@ class LUGradKernel : public framework::OpKernel<T> {
       const auto W = phidims[phidims.size() - 1];
       platform::ForRange<DeviceContext> x_for_range(dev_ctx,
                                                     phi_complement.numel());
-      TrilTriuCompute<T> triu_computer(phi_complement.data<T>(), 0, false, H, W,
-                                       phi_complement_u.data<T>());
+      phi::funcs::TrilTriuCompute<T> triu_computer(
+          phi_complement.data<T>(), 0, false, H, W, phi_complement_u.data<T>());
       x_for_range(triu_computer);
 
       Tensor_Sub<DeviceContext, T>(dev_ctx, phi, phi_complement_u, &phi);
diff --git a/paddle/fluid/operators/lu_unpack_op.h b/paddle/fluid/operators/lu_unpack_op.h
index d2303f2c08d..e4100867dc6 100644
--- a/paddle/fluid/operators/lu_unpack_op.h
+++ b/paddle/fluid/operators/lu_unpack_op.h
@@ -16,7 +16,8 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/lu_op.h"
-#include "paddle/fluid/operators/tril_triu_op.h"
+#include "paddle/fluid/platform/for_range.h"
+#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
 
 namespace paddle {
 namespace operators {
@@ -87,7 +88,8 @@ class LU_UnpackGradKernel : public framework::OpKernel<T> {
       auto W = ldims[ldims.size() - 1];
       auto L_dataptr = dl_tril.mutable_data<T>(dev_ctx.GetPlace());
       platform::ForRange<DeviceContext> l_for_range(dev_ctx, dl->numel());
-      TrilTriuCompute<T> tril_computer(dl->data<T>(), -1, true, H, W, L_dataptr);
+      phi::funcs::TrilTriuCompute<T> tril_computer(dl->data<T>(), -1, true, H,
+                                                   W, L_dataptr);
       l_for_range(tril_computer);
 
       const auto udims = du->dims();
@@ -96,7 +98,8 @@ class LU_UnpackGradKernel : public framework::OpKernel<T> {
       W = udims[udims.size() - 1];
       auto U_dataptr = du_triu.mutable_data<T>(dev_ctx.GetPlace());
       platform::ForRange<DeviceContext> u_for_range(dev_ctx, du->numel());
-      TrilTriuCompute<T> triu_computer(du->data<T>(), 0, false, H, W, U_dataptr);
+      phi::funcs::TrilTriuCompute<T> triu_computer(du->data<T>(), 0, false, H,
+                                                   W, U_dataptr);
       u_for_range(triu_computer);
 
       auto xdims = dx->dims();
diff --git a/paddle/fluid/operators/multiplex_op.cc b/paddle/fluid/operators/multiplex_op.cc
index 313a479ea30..8771a6573cb 100644
--- a/paddle/fluid/operators/multiplex_op.cc
+++ b/paddle/fluid/operators/multiplex_op.cc
@@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/multiplex_op.h"
 #include <memory>
 #include <vector>
+#include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -169,15 +169,3 @@ REGISTER_OPERATOR(multiplex, ops::MultiplexOp, ops::MultiplexOpMaker,
                   ops::MultiplexGradMaker<paddle::framework::OpDesc>,
                   ops::MultiplexGradMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(multiplex_grad, ops::MultiplexGradOp);
-REGISTER_OP_CPU_KERNEL(
-    multiplex,
-    ops::MultiplexCPUKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::MultiplexCPUKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::MultiplexCPUKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::MultiplexCPUKernel<paddle::platform::CPUDeviceContext, int64_t>);
-REGISTER_OP_CPU_KERNEL(
-    multiplex_grad,
-    ops::MultiplexGradCPUKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::MultiplexGradCPUKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::MultiplexGradCPUKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::MultiplexGradCPUKernel<paddle::platform::CPUDeviceContext, int64_t>);
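Both the removed fluid kernels below and the new phi kernels later in this patch implement the same row-level gather: for each output row i, Ids[i] selects the candidate tensor whose row i is copied through. A standalone sketch of that logic, for orientation only (plain C++; Multiplex here is a local stand-in invented for this sketch, not a Paddle API):

// Row-gather at the heart of multiplex: out[i, :] = ins[ids[i]][i, :].
// The real kernels use memory::Copy on DenseTensor buffers instead.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

using Matrix = std::vector<std::vector<float>>;  // rows x cols

Matrix Multiplex(const std::vector<Matrix>& ins,
                 const std::vector<int32_t>& ids) {
  size_t rows = ins[0].size();
  Matrix out(rows);
  for (size_t i = 0; i < rows; ++i) {
    int32_t k = ids[i];
    // Same preconditions the kernels enforce with PADDLE_ENFORCE_{GE,LT}.
    assert(k >= 0 && static_cast<size_t>(k) < ins.size());
    out[i] = ins[k][i];  // copy one row from the k-th candidate
  }
  return out;
}

int main() {
  Matrix a = {{1, 1}, {1, 1}}, b = {{2, 2}, {2, 2}};
  Matrix out = Multiplex({a, b}, {1, 0});  // row 0 from b, row 1 from a
  std::cout << out[0][0] << " " << out[1][0] << "\n";  // prints: 2 1
}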
diff --git a/paddle/fluid/operators/multiplex_op.cu b/paddle/fluid/operators/multiplex_op.cu
deleted file mode 100644
index 0a32ee96fb6..00000000000
--- a/paddle/fluid/operators/multiplex_op.cu
+++ /dev/null
@@ -1,117 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/multiplex_op.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename Place, typename T>
-class MultiplexGPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const {
-    auto ins = ctx.MultiInput<Tensor>("X");
-    auto* ids = ctx.Input<Tensor>("Ids");
-    auto* out = ctx.Output<Tensor>("Out");
-    out->mutable_data<T>(ctx.GetPlace());
-
-    for (size_t i = 0; i < ins.size(); ++i) {
-      PADDLE_ENFORCE_GT(
-          ins[i]->numel(), 0,
-          platform::errors::OutOfRange(
-              "indexing will be out of bounds with size 0 for the %d-th input.",
-              i));
-    }
-
-    auto rows = ins[0]->dims()[0];
-    auto cols = ins[0]->numel() / rows;
-    // copy index to cpu
-    Tensor index_t_cpu;
-    paddle::framework::TensorCopySync(*ids, platform::CPUPlace(), &index_t_cpu);
-    auto* index = index_t_cpu.data<int32_t>();
-    auto stream = ctx.cuda_device_context().stream();
-    platform::CUDAPlace place = ctx.GetPlace();
-    for (auto i = 0; i < rows; i++) {
-      int32_t k = index[i];
-      PADDLE_ENFORCE_GE(k, 0, platform::errors::PreconditionNotMet(
-                                  "index must be nonnegative."));
-      PADDLE_ENFORCE_LT(static_cast<size_t>(k), ins.size(),
-                        platform::errors::PreconditionNotMet(
-                            "index exceeds the number of candidate tensors."));
-      memory::Copy(place, out->data<T>() + i * cols, place,
-                   ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
-    }
-  }
-};
-
-template <typename Place, typename T>
-class MultiplexGradGPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const {
-    auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* ids = ctx.Input<Tensor>("Ids");
-    auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
-
-    size_t idx = -1UL;
-    for (size_t i = 0; i < d_ins.size(); i++) {
-      if (d_ins[i]) {
-        d_ins[i]->mutable_data<T>(ctx.GetPlace());
-        auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
-        t.device(*ctx.template device_context<Place>().eigen_device()) =
-            t.constant(static_cast<T>(0));
-
-        idx = i;
-      }
-    }
-
-    if (idx == -1UL) return;
-
-    auto rows = d_ins[idx]->dims()[0];
-    auto cols = d_ins[idx]->numel() / rows;
-    // copy index to cpu
-    Tensor index_t_cpu;
-    paddle::framework::TensorCopySync(*ids, platform::CPUPlace(), &index_t_cpu);
-    auto* index = index_t_cpu.data<int32_t>();
-
-    auto stream = ctx.cuda_device_context().stream();
-    platform::CUDAPlace place = ctx.GetPlace();
-    for (auto i = 0; i < rows; i++) {
-      size_t k = static_cast<size_t>(index[i]);
-      if (d_ins[k]) {
-        memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
-                     d_out->data<T>() + i * cols, cols * sizeof(T), stream);
-      }
-    }
-  }
-};
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-
-REGISTER_OP_CUDA_KERNEL(
-    multiplex,
-    ops::MultiplexGPUKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::MultiplexGPUKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::MultiplexGPUKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::MultiplexGPUKernel<paddle::platform::CUDADeviceContext, int64_t>);
-REGISTER_OP_CUDA_KERNEL(
-    multiplex_grad,
-    ops::MultiplexGradGPUKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::MultiplexGradGPUKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::MultiplexGradGPUKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::MultiplexGradGPUKernel<paddle::platform::CUDADeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/multiplex_op.h b/paddle/fluid/operators/multiplex_op.h
deleted file mode 100644
index 1d0a009edee..00000000000
--- a/paddle/fluid/operators/multiplex_op.h
+++ /dev/null
@@ -1,96 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/memory/memcpy.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename DeviceContext, typename T>
-class MultiplexCPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const {
-    auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto ids = ctx.Input<framework::Tensor>("Ids");
-    auto* out = ctx.Output<framework::Tensor>("Out");
-
-    out->mutable_data<T>(ctx.GetPlace());
-
-    for (size_t i = 0; i < ins.size(); ++i) {
-      PADDLE_ENFORCE_GT(
-          ins[i]->numel(), 0,
-          platform::errors::OutOfRange(
-              "indexing will be out of bounds with size 0 for the %d-th input.",
-              i));
-    }
-
-    auto rows = ins[0]->dims()[0];
-    auto cols = ins[0]->numel() / rows;
-    auto index = ids->data<int32_t>();
-    platform::CPUPlace place = ctx.GetPlace();
-    for (auto i = 0; i < rows; i++) {
-      int32_t k = index[i];
-      PADDLE_ENFORCE_GE(k, 0, platform::errors::PreconditionNotMet(
-                                  "index must be nonnegative."));
-      PADDLE_ENFORCE_LT(static_cast<size_t>(k), ins.size(),
-                        platform::errors::PreconditionNotMet(
-                            "index exceeds the number of candidate tensors."));
-      memory::Copy(place, out->data<T>() + i * cols, place,
-                   ins[k]->data<T>() + i * cols, cols * sizeof(T));
-    }
-  }
-};
-
-template <typename DeviceContext, typename T>
-class MultiplexGradCPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const {
-    auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* ids = ctx.Input<framework::Tensor>("Ids");
-    auto d_ins =
-        ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
-
-    size_t idx = -1UL;
-    for (size_t i = 0; i < d_ins.size(); i++) {
-      if (d_ins[i]) {
-        d_ins[i]->mutable_data<T>(ctx.GetPlace());
-        auto t = framework::EigenVector<T>::Flatten(*d_ins[i]);
-        t.device(*ctx.template device_context<DeviceContext>().eigen_device()) =
-            t.constant(static_cast<T>(0));
-
-        idx = i;
-      }
-    }
-
-    if (idx == -1UL) return;
-
-    auto rows = d_ins[idx]->dims()[0];
-    auto cols = d_ins[idx]->numel() / rows;
-    auto* index = ids->data<int32_t>();
-    platform::CPUPlace place = ctx.GetPlace();
-    for (auto i = 0; i < rows; i++) {
-      size_t k = static_cast<size_t>(index[i]);
-      if (d_ins[k]) {
-        memory::Copy(place, d_ins[k]->data<T>() + i * cols, place,
-                     d_out->data<T>() + i * cols, cols * sizeof(T));
-      }
-    }
-  }
-};
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/qr_op.cc b/paddle/fluid/operators/qr_op.cc
index 40e3cbde3b0..82fc9ef1b78 100644
--- a/paddle/fluid/operators/qr_op.cc
+++ b/paddle/fluid/operators/qr_op.cc
@@ -145,8 +145,6 @@ REGISTER_OPERATOR(qr, ops::QrOp, ops::QrOpMaker,
 
 REGISTER_OPERATOR(qr_grad, ops::QrGradOp);
 
-REGISTER_OP_CPU_KERNEL(qr, ops::QrCPUKernel<float>, ops::QrCPUKernel<double>);
-
 REGISTER_OP_CPU_KERNEL(
     qr_grad, ops::QrGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::QrGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/qr_op.h b/paddle/fluid/operators/qr_op.h
index f09a07e96cd..5ef02d89427 100644
--- a/paddle/fluid/operators/qr_op.h
+++ b/paddle/fluid/operators/qr_op.h
@@ -48,85 +48,6 @@ static inline std::tuple<bool, bool> _parse_qr_mode(std::string mode) {
   return std::make_tuple(compute_q, reduced);
 }
 
-template <typename T>
-class QrCPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    bool compute_q;
-    bool reduced_mode;
-    const Tensor& x = *context.Input<Tensor>("X");
-    Tensor& q = *context.Output<Tensor>("Q");
-    Tensor& r = *context.Output<Tensor>("R");
-    std::string mode = context.Attr<std::string>("mode");
-    std::tie(compute_q, reduced_mode) = _parse_qr_mode(mode);
-
-    auto numel = x.numel();
-    PADDLE_ENFORCE_GT(numel, 0, platform::errors::PreconditionNotMet(
-                                    "The input of QR is empty."));
-    auto x_dims = x.dims();
-    int x_rank = x_dims.size();
-    int m = x_dims[x_rank - 2];
-    int n = x_dims[x_rank - 1];
-    int min_mn = std::min(m, n);
-    int k = reduced_mode ? min_mn : m;
-    int batch_size = numel / (m * n);
-    int x_stride = m * n;
-    int q_stride = m * k;
-    int r_stride = k * n;
-
-    auto* x_data = x.data<phi::dtype::Real<T>>();
-    T* q_data = nullptr;
-    if (compute_q) {
-      q_data = q.mutable_data<phi::dtype::Real<T>>(
-          context.GetPlace(),
-          size_t(batch_size * m * k * sizeof(phi::dtype::Real<T>)));
-      memset(q_data, 0,
-             size_t(batch_size * m * k * sizeof(phi::dtype::Real<T>)));
-    }
-    auto* r_data = r.mutable_data<phi::dtype::Real<T>>(
-        context.GetPlace(),
-        size_t(batch_size * k * n * sizeof(phi::dtype::Real<T>)));
-    memset(r_data, 0, size_t(batch_size * k * n * sizeof(phi::dtype::Real<T>)));
-
-    // Implement QR by calling Eigen
-    for (int i = 0; i < batch_size; ++i) {
-      const T* x_matrix_ptr = x_data + i * x_stride;
-      T* r_matrix_ptr = r_data + i * r_stride;
-      using EigenDynamicMatrix =
-          Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
-      auto x_matrix = Eigen::Map<const EigenDynamicMatrix>(x_matrix_ptr, m, n);
-      Eigen::HouseholderQR<EigenDynamicMatrix> qr(x_matrix);
-      if (reduced_mode) {
-        auto qr_top_matrix = qr.matrixQR().block(0, 0, min_mn, n);
-        auto r_matrix_view =
-            qr_top_matrix.template triangularView<Eigen::Upper>();
-        auto r_matrix = EigenDynamicMatrix(r_matrix_view);
-        memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T));
-      } else {
-        auto r_matrix_view =
-            qr.matrixQR().template triangularView<Eigen::Upper>();
-        auto r_matrix = EigenDynamicMatrix(r_matrix_view);
-        memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T));
-      }
-
-      if (compute_q) {
-        T* q_matrix_ptr = q_data + i * q_stride;
-        if (reduced_mode) {
-          auto q_matrix =
-              qr.householderQ() * EigenDynamicMatrix::Identity(m, min_mn);
-          q_matrix.transposeInPlace();
-          memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T));
-        } else {
-          auto q_matrix =
-              qr.householderQ() * EigenDynamicMatrix::Identity(m, m);
-          q_matrix.transposeInPlace();
-          memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T));
-        }
-      }
-    }
-  }
-};
-
 template <typename DeviceContext, typename T>
 class QrGradKernel : public framework::OpKernel<T> {
  public:
diff --git a/paddle/fluid/operators/tril_triu_op.cc b/paddle/fluid/operators/tril_triu_op.cc
index 3e943c62e1c..c8010e8a128 100644
--- a/paddle/fluid/operators/tril_triu_op.cc
+++ b/paddle/fluid/operators/tril_triu_op.cc
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/tril_triu_op.h"
 #include <memory>
+#include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -104,19 +104,3 @@ REGISTER_OPERATOR(tril_triu, ops::TrilTriuOp, ops::TrilTriuOpMaker,
                   ops::TrilTriuGradOpMaker<paddle::framework::OpDesc>,
                   ops::TrilTriuGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(tril_triu_grad, ops::TrilTriuGradOp);
-REGISTER_OP_CPU_KERNEL(
-    tril_triu, ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, bool>,
-    ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::TrilTriuOpKernel<paddle::platform::CPUDeviceContext,
-                          paddle::platform::float16>);
-REGISTER_OP_CPU_KERNEL(
-    tril_triu_grad,
-    ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, bool>,
-    ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::TrilTriuGradOpKernel<paddle::platform::CPUDeviceContext,
-                              paddle::platform::float16>);
diff --git a/paddle/fluid/operators/tril_triu_op.cu b/paddle/fluid/operators/tril_triu_op.cu
deleted file mode 100644
index 9cbbdeeb2ce..00000000000
--- a/paddle/fluid/operators/tril_triu_op.cu
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/tril_triu_op.h"
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-REGISTER_OP_CUDA_KERNEL(
-    tril_triu, ops::TrilTriuOpKernel<plat::CUDADeviceContext, bool>,
-    ops::TrilTriuOpKernel<plat::CUDADeviceContext, float>,
-    ops::TrilTriuOpKernel<plat::CUDADeviceContext, double>,
-    ops::TrilTriuOpKernel<plat::CUDADeviceContext, int>,
-    ops::TrilTriuOpKernel<plat::CUDADeviceContext, int64_t>,
-    ops::TrilTriuOpKernel<plat::CUDADeviceContext, plat::float16>);
-REGISTER_OP_CUDA_KERNEL(
-    tril_triu_grad,
-    ops::TrilTriuGradOpKernel<plat::CUDADeviceContext, bool>,
-    ops::TrilTriuGradOpKernel<plat::CUDADeviceContext, float>,
-    ops::TrilTriuGradOpKernel<plat::CUDADeviceContext, double>,
-    ops::TrilTriuGradOpKernel<plat::CUDADeviceContext, int>,
-    ops::TrilTriuGradOpKernel<plat::CUDADeviceContext, int64_t>,
-    ops::TrilTriuGradOpKernel<plat::CUDADeviceContext, plat::float16>);
diff --git a/paddle/fluid/operators/tril_triu_op.h b/paddle/fluid/operators/tril_triu_op.h
deleted file mode 100644
index 3150b7617d1..00000000000
--- a/paddle/fluid/operators/tril_triu_op.h
+++ /dev/null
@@ -1,102 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/float16.h"
-#include "paddle/fluid/platform/for_range.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class TrilTriuCompute {
- public:
-  HOSTDEVICE TrilTriuCompute(const T* in, const int diagonal, const bool lower,
-                             const int64_t H, const int64_t W, T* out)
-      : in_(in), diagonal_(diagonal), lower_(lower), H_(H), W_(W), out_(out) {}
-
-  HOSTDEVICE void operator()(int64_t idx) {
-    const int64_t row = (idx / W_) % H_;
-    const int64_t col = idx % W_;
-    const bool mask =
-        lower_ ? (col - row > diagonal_) : (col - row < diagonal_);
-    out_[idx] = mask ? static_cast<T>(0) : in_[idx];
-  }
-
- private:
-  const T* in_;
-  const int diagonal_;
-  const bool lower_;
-  const int64_t H_;
-  const int64_t W_;
-  T* out_;
-};
-
-template <typename DeviceContext, typename T>
-class TrilTriuOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const auto* x = context.Input<framework::Tensor>("X");
-    const auto* x_data = x->data<T>();
-    auto* out = context.Output<framework::Tensor>("Out");
-    auto* out_data = out->mutable_data<T>(context.GetPlace());
-
-    const int diagonal = context.Attr<int>("diagonal");
-    const bool lower = context.Attr<bool>("lower");
-
-    const auto& dims = x->dims();
-    const auto H = dims[dims.size() - 2];
-    const auto W = dims[dims.size() - 1];
-
-    platform::ForRange<DeviceContext> for_range(
-        context.template device_context<DeviceContext>(),
-        static_cast<size_t>(x->numel()));
-
-    paddle::operators::TrilTriuCompute<T> tril_triu_computer(
-        x_data, diagonal, lower, H, W, out_data);
-    for_range(tril_triu_computer);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class TrilTriuGradOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const auto* d_out =
-        context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    const auto* dout_data = d_out->data<T>();
-    auto* d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
-    auto* dx_data = d_x->mutable_data<T>(context.GetPlace());
-
-    const int diagonal = context.Attr<int>("diagonal");
-    const bool lower = context.Attr<bool>("lower");
-
-    const auto& dims = d_out->dims();
-    const auto H = dims[dims.size() - 2];
-    const auto W = dims[dims.size() - 1];
-
-    platform::ForRange<DeviceContext> for_range(
-        context.template device_context<DeviceContext>(),
-        static_cast<size_t>(d_out->numel()));
-
-    paddle::operators::TrilTriuCompute<T> tril_triu_grad_computer(
-        dout_data, diagonal, lower, H, W, dx_data);
-    for_range(tril_triu_grad_computer);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/tril_triu_op_npu.cc b/paddle/fluid/operators/tril_triu_op_npu.cc
index ad1c1814c05..4145730357d 100644
--- a/paddle/fluid/operators/tril_triu_op_npu.cc
+++ b/paddle/fluid/operators/tril_triu_op_npu.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/tril_triu_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
 
 namespace paddle {
diff --git a/paddle/fluid/operators/tril_triu_op_xpu.cc b/paddle/fluid/operators/tril_triu_op_xpu.cc
index e36cbcf228c..a44ea8ff689 100644
--- a/paddle/fluid/operators/tril_triu_op_xpu.cc
+++ b/paddle/fluid/operators/tril_triu_op_xpu.cc
@@ -11,7 +11,7 @@ limitations under the License. */
 
 #ifdef PADDLE_WITH_XPU
 
-#include "paddle/fluid/operators/tril_triu_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device/device_wrapper.h"
 
 namespace paddle {
diff --git a/paddle/phi/kernels/cpu/multiplex_grad_kernel.cc b/paddle/phi/kernels/cpu/multiplex_grad_kernel.cc
new file mode 100644
index 00000000000..f5a426e93db
--- /dev/null
+++ b/paddle/phi/kernels/cpu/multiplex_grad_kernel.cc
@@ -0,0 +1,65 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/multiplex_grad_kernel.h"
+
+#include "paddle/fluid/memory/memcpy.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MultiplexGradKernel(const Context& ctx,
+                         const DenseTensor& ids,
+                         const DenseTensor& out_grad,
+                         std::vector<DenseTensor*> ins_grad) {
+  size_t idx = -1UL;
+  for (size_t i = 0; i < ins_grad.size(); i++) {
+    if (ins_grad[i]) {
+      ctx.template Alloc<T>(ins_grad[i]);
+      auto t = phi::EigenVector<T>::Flatten(*ins_grad[i]);
+      t.device(*ctx.eigen_device()) = t.constant(static_cast<T>(0));
+      idx = i;
+    }
+  }
+  if (idx == -1UL) return;
+
+  auto rows = ins_grad[idx]->dims()[0];
+  auto cols = ins_grad[idx]->numel() / rows;
+  auto* index = ids.data<int32_t>();
+  for (auto i = 0; i < rows; i++) {
+    size_t k = static_cast<size_t>(index[i]);
+    if (ins_grad[k]) {
+      paddle::memory::Copy(ctx.GetPlace(),
+                           ins_grad[k]->data<T>() + i * cols,
+                           ctx.GetPlace(),
+                           out_grad.data<T>() + i * cols,
+                           cols * sizeof(T));
+    }
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(multiplex_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::MultiplexGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/cpu/multiplex_kernel.cc b/paddle/phi/kernels/cpu/multiplex_kernel.cc
new file mode 100644
index 00000000000..2d9f4c51a98
--- /dev/null
+++ b/paddle/phi/kernels/cpu/multiplex_kernel.cc
@@ -0,0 +1,65 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/multiplex_kernel.h"
+
+#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MultiplexKernel(const Context& ctx,
+                     const std::vector<const DenseTensor*>& ins,
+                     const DenseTensor& ids,
+                     DenseTensor* out) {
+  ctx.template Alloc<T>(out);
+  for (size_t i = 0; i < ins.size(); ++i) {
+    PADDLE_ENFORCE_GT(
+        ins[i]->numel(),
+        0,
+        errors::OutOfRange(
+            "indexing will be out of bounds with size 0 for the %d-th input.",
+            i));
+  }
+  auto rows = ins[0]->dims()[0];
+  auto cols = ins[0]->numel() / rows;
+  auto index = ids.data<int32_t>();
+  for (auto i = 0; i < rows; i++) {
+    int32_t k = index[i];
+    PADDLE_ENFORCE_GE(
+        k, 0, errors::PreconditionNotMet("index must be nonnegative."));
+    PADDLE_ENFORCE_LT(static_cast<size_t>(k),
+                      ins.size(),
+                      errors::PreconditionNotMet(
+                          "index exceeds the number of candidate tensors."));
+    paddle::memory::Copy(ctx.GetPlace(),
+                         out->data<T>() + i * cols,
+                         ctx.GetPlace(),
+                         ins[k]->data<T>() + i * cols,
+                         cols * sizeof(T));
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(multiplex,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::MultiplexKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/cpu/qr_kernel.cc b/paddle/phi/kernels/cpu/qr_kernel.cc
new file mode 100644
index 00000000000..e2e32567441
--- /dev/null
+++ b/paddle/phi/kernels/cpu/qr_kernel.cc
@@ -0,0 +1,116 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <Eigen/Dense>
+
+#include "paddle/phi/kernels/qr_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/complex_functors.h"
+
+namespace phi {
+
+static inline std::tuple<bool, bool> ParseQrMode(const std::string& mode) {
+  bool compute_q;
+  bool reduced;
+  if (mode == "reduced") {
+    compute_q = true;
+    reduced = true;
+  } else if (mode == "complete") {
+    compute_q = true;
+    reduced = false;
+  } else if (mode == "r") {
+    compute_q = false;
+    reduced = true;
+  } else {
+    PADDLE_THROW(errors::InvalidArgument(
+        "QR received unrecognized mode '%s'"
+        " but expected one of 'reduced' (default), 'r', or 'complete'",
+        mode));
+  }
+  return std::make_tuple(compute_q, reduced);
+}
+
+template <typename T, typename Context>
+void QrKernel(const Context& ctx,
+              const DenseTensor& x,
+              const std::string& mode,
+              DenseTensor* q,
+              DenseTensor* r) {
+  bool compute_q;
+  bool reduced_mode;
+  std::tie(compute_q, reduced_mode) = ParseQrMode(mode);
+  auto numel = x.numel();
+  PADDLE_ENFORCE_GT(
+      numel, 0, errors::PreconditionNotMet("The input of QR is empty."));
+  auto x_dims = x.dims();
+  int x_rank = x_dims.size();
+  int m = x_dims[x_rank - 2];
+  int n = x_dims[x_rank - 1];
+  int min_mn = std::min(m, n);
+  int k = reduced_mode ? min_mn : m;
+  int batch_size = numel / (m * n);
+  int x_stride = m * n;
+  int q_stride = m * k;
+  int r_stride = k * n;
+  auto* x_data = x.data<phi::dtype::Real<T>>();
+  T* q_data = nullptr;
+  if (compute_q) {
+    q_data = ctx.template Alloc<phi::dtype::Real<T>>(
+        q, batch_size * m * k * sizeof(phi::dtype::Real<T>));
+  }
+  auto* r_data = ctx.template Alloc<phi::dtype::Real<T>>(
+      r, batch_size * k * n * sizeof(phi::dtype::Real<T>));
+
+  // Implement QR by calling Eigen
+  for (int i = 0; i < batch_size; ++i) {
+    const T* x_matrix_ptr = x_data + i * x_stride;
+    T* r_matrix_ptr = r_data + i * r_stride;
+    using EigenDynamicMatrix =
+        Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+    auto x_matrix = Eigen::Map<const EigenDynamicMatrix>(x_matrix_ptr, m, n);
+    Eigen::HouseholderQR<EigenDynamicMatrix> qr(x_matrix);
+    if (reduced_mode) {
+      auto qr_top_matrix = qr.matrixQR().block(0, 0, min_mn, n);
+      auto r_matrix_view =
+          qr_top_matrix.template triangularView<Eigen::Upper>();
+      auto r_matrix = EigenDynamicMatrix(r_matrix_view);
+      memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T));
+    } else {
+      auto r_matrix_view =
+          qr.matrixQR().template triangularView<Eigen::Upper>();
+      auto r_matrix = EigenDynamicMatrix(r_matrix_view);
+      memcpy(r_matrix_ptr, r_matrix.data(), r_matrix.size() * sizeof(T));
+    }
+
+    if (compute_q) {
+      T* q_matrix_ptr = q_data + i * q_stride;
+      if (reduced_mode) {
+        auto q_matrix =
+            qr.householderQ() * EigenDynamicMatrix::Identity(m, min_mn);
+        q_matrix.transposeInPlace();
+        memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T));
+      } else {
+        auto q_matrix = qr.householderQ() * EigenDynamicMatrix::Identity(m, m);
+        q_matrix.transposeInPlace();
+        memcpy(q_matrix_ptr, q_matrix.data(), q_matrix.size() * sizeof(T));
+      }
+    }
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(qr, CPU, ALL_LAYOUT, phi::QrKernel, float, double) {}
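Reduced ("economy") mode above returns a min(m, n)-column Q with a min(m, n) x n R, while complete mode returns a square m x m Q. A standalone sketch using the same Eigen calls the CPU kernel relies on (hedged: assumes Eigen is on the include path; all names are local to this sketch):

// Reduced vs. complete QR with Eigen::HouseholderQR, as in the kernel above.
#include <Eigen/Dense>
#include <iostream>

int main() {
  using Mat =
      Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  Mat a = Mat::Random(4, 2);  // m = 4, n = 2, min_mn = 2
  Eigen::HouseholderQR<Mat> qr(a);

  // Thin Q (reduced) vs. square Q (complete), extracted the same way the
  // kernel does: householderQ() times an identity of the target width.
  Mat q_reduced = qr.householderQ() * Mat::Identity(4, 2);   // 4 x 2
  Mat q_complete = qr.householderQ() * Mat::Identity(4, 4);  // 4 x 4
  Mat r = Mat(qr.matrixQR().block(0, 0, 2, 2).triangularView<Eigen::Upper>());

  // The reduced factors already reconstruct A: ||Q_red * R - A|| is ~0.
  std::cout << (q_reduced * r - a).norm() << "\n";
}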
diff --git a/paddle/phi/kernels/cpu/tril_triu_grad_kernel.cc b/paddle/phi/kernels/cpu/tril_triu_grad_kernel.cc
new file mode 100644
index 00000000000..14aca258a2c
--- /dev/null
+++ b/paddle/phi/kernels/cpu/tril_triu_grad_kernel.cc
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/impl/tril_triu_grad_kernel_impl.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(tril_triu_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::TrilTriuGradKernel,
+                   bool,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/cpu/tril_triu_kernel.cc b/paddle/phi/kernels/cpu/tril_triu_kernel.cc
new file mode 100644
index 00000000000..a3d20e55e21
--- /dev/null
+++ b/paddle/phi/kernels/cpu/tril_triu_kernel.cc
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(tril_triu,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::TrilTriuKernel,
+                   bool,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/funcs/tril_triu_compute.h b/paddle/phi/kernels/funcs/tril_triu_compute.h
new file mode 100644
index 00000000000..d2b6f1e559d
--- /dev/null
+++ b/paddle/phi/kernels/funcs/tril_triu_compute.h
@@ -0,0 +1,48 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+namespace phi {
+namespace funcs {
+
+template <typename T>
+class TrilTriuCompute {
+ public:
+  HOSTDEVICE TrilTriuCompute(const T* in,
+                             const int diagonal,
+                             const bool lower,
+                             const int64_t H,
+                             const int64_t W,
+                             T* out)
+      : in_(in), diagonal_(diagonal), lower_(lower), H_(H), W_(W), out_(out) {}
+
+  HOSTDEVICE void operator()(int64_t idx) {
+    const int64_t row = (idx / W_) % H_;
+    const int64_t col = idx % W_;
+    const bool mask =
+        lower_ ? (col - row > diagonal_) : (col - row < diagonal_);
+    out_[idx] = mask ? static_cast<T>(0) : in_[idx];
+  }
+
+ private:
+  const T* in_;
+  const int diagonal_;
+  const bool lower_;
+  const int64_t H_;
+  const int64_t W_;
+  T* out_;
+};
+}  // namespace funcs
+}  // namespace phi
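This functor is the entire tril/triu computation: map a flat index to (row, col), then zero out whichever side of the shifted diagonal the op discards. A plain-C++ sketch of the same arithmetic outside the HOSTDEVICE/ForRange machinery (local names, CPU only):

// The TrilTriuCompute predicate applied to a 3 x 3 tril with diagonal = 0.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t H = 3, W = 3;
  const int diagonal = 0;
  const bool lower = true;  // tril; false would give triu
  std::vector<float> in(H * W, 1.0f), out(H * W, 0.0f);

  for (int64_t idx = 0; idx < H * W; ++idx) {
    const int64_t row = (idx / W) % H;
    const int64_t col = idx % W;
    // Same predicate as the functor: discard above (tril) or below (triu)
    // the diagonal shifted by `diagonal`.
    const bool mask = lower ? (col - row > diagonal) : (col - row < diagonal);
    out[idx] = mask ? 0.0f : in[idx];
  }

  for (int64_t idx = 0; idx < H * W; ++idx)
    std::cout << out[idx] << (idx % W == W - 1 ? '\n' : ' ');
  // Prints the lower triangle:
  // 1 0 0
  // 1 1 0
  // 1 1 1
}

The gradient kernels reuse the identical functor, since the derivative of a fixed 0/1 mask is the same mask applied to the incoming gradient.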
diff --git a/paddle/phi/kernels/gpu/multiplex_grad_kernel.cu b/paddle/phi/kernels/gpu/multiplex_grad_kernel.cu
new file mode 100644
index 00000000000..21576ab608d
--- /dev/null
+++ b/paddle/phi/kernels/gpu/multiplex_grad_kernel.cu
@@ -0,0 +1,68 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/multiplex_grad_kernel.h"
+
+#include "paddle/phi/api/lib/utils/tensor_utils.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MultiplexGradKernel(const Context& ctx,
+                         const DenseTensor& ids,
+                         const DenseTensor& out_grad,
+                         std::vector<DenseTensor*> ins_grad) {
+  size_t idx = -1UL;
+  for (size_t i = 0; i < ins_grad.size(); i++) {
+    if (ins_grad[i]) {
+      ctx.template Alloc<T>(ins_grad[i]);
+      auto t = phi::EigenVector<T>::Flatten(*ins_grad[i]);
+      t.device(*ctx.eigen_device()) = t.constant(static_cast<T>(0));
+      idx = i;
+    }
+  }
+  if (idx == -1UL) return;
+
+  auto rows = ins_grad[idx]->dims()[0];
+  auto cols = ins_grad[idx]->numel() / rows;
+  DenseTensor index_t_cpu;
+  paddle::framework::TensorCopySync(ids, phi::CPUPlace(), &index_t_cpu);
+  auto* index = index_t_cpu.data<int32_t>();
+  auto stream = ctx.stream();
+  for (auto i = 0; i < rows; i++) {
+    size_t k = static_cast<size_t>(index[i]);
+    if (ins_grad[k]) {
+      paddle::memory::Copy(ctx.GetPlace(),
+                           ins_grad[k]->data<T>() + i * cols,
+                           ctx.GetPlace(),
+                           out_grad.data<T>() + i * cols,
+                           cols * sizeof(T),
+                           stream);
+    }
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(multiplex_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MultiplexGradKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/gpu/multiplex_kernel.cu b/paddle/phi/kernels/gpu/multiplex_kernel.cu
new file mode 100644
index 00000000000..743448a4686
--- /dev/null
+++ b/paddle/phi/kernels/gpu/multiplex_kernel.cu
@@ -0,0 +1,70 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/multiplex_kernel.h"
+
+#include "paddle/phi/api/lib/utils/tensor_utils.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MultiplexKernel(const Context& ctx,
+                     const std::vector<const DenseTensor*>& ins,
+                     const DenseTensor& ids,
+                     DenseTensor* out) {
+  ctx.template Alloc<T>(out);
+  for (size_t i = 0; i < ins.size(); ++i) {
+    PADDLE_ENFORCE_GT(
+        ins[i]->numel(),
+        0,
+        errors::OutOfRange(
+            "indexing will be out of bounds with size 0 for the %d-th input.",
+            i));
+  }
+
+  auto rows = ins[0]->dims()[0];
+  auto cols = ins[0]->numel() / rows;
+  DenseTensor index_t_cpu;
+  paddle::framework::TensorCopySync(ids, phi::CPUPlace(), &index_t_cpu);
+  auto* index = index_t_cpu.data<int32_t>();
+  auto stream = ctx.stream();
+  for (auto i = 0; i < rows; i++) {
+    int32_t k = index[i];
+    PADDLE_ENFORCE_GE(
+        k, 0, errors::PreconditionNotMet("index must be nonnegative."));
+    PADDLE_ENFORCE_LT(static_cast<size_t>(k),
+                      ins.size(),
+                      errors::PreconditionNotMet(
+                          "index exceeds the number of candidate tensors."));
+    paddle::memory::Copy(ctx.GetPlace(),
+                         out->data<T>() + i * cols,
+                         ctx.GetPlace(),
+                         ins[k]->data<T>() + i * cols,
+                         cols * sizeof(T),
+                         stream);
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(multiplex,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::MultiplexKernel,
+                   float,
+                   double,
+                   int,
+                   int64_t) {}
diff --git a/paddle/phi/kernels/gpu/tril_triu_grad_kernel.cu b/paddle/phi/kernels/gpu/tril_triu_grad_kernel.cu
new file mode 100644
index 00000000000..bc3ef1bc623
--- /dev/null
+++ b/paddle/phi/kernels/gpu/tril_triu_grad_kernel.cu
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/impl/tril_triu_grad_kernel_impl.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(tril_triu_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::TrilTriuGradKernel,
+                   bool,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/gpu/tril_triu_kernel.cu b/paddle/phi/kernels/gpu/tril_triu_kernel.cu
new file mode 100644
index 00000000000..8c48edf9eff
--- /dev/null
+++ b/paddle/phi/kernels/gpu/tril_triu_kernel.cu
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/impl/tril_triu_kernel_impl.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+PD_REGISTER_KERNEL(tril_triu,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::TrilTriuKernel,
+                   bool,
+                   float,
+                   double,
+                   int,
+                   int64_t,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h b/paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h
index 9f557e74637..72741e6d3a0 100644
--- a/paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h
@@ -24,13 +24,12 @@
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/common_shape.h"
 #include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
 #include "paddle/phi/kernels/funcs/matrix_reduce.h"
+#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
 #include "paddle/phi/kernels/math_kernel.h"
 #include "paddle/phi/kernels/transpose_kernel.h"
 
-// See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/operators/tril_triu_op.h"
-
 namespace phi {
 
 template <typename T, typename Context>
@@ -115,7 +114,7 @@ void CholeskySolveGradKernel(const Context& dev_ctx,
     const auto H = y_bst_dims_vec[y_bst_ndim - 2];
     const auto W = y_bst_dims_vec[y_bst_ndim - 1];
     phi::funcs::ForRange<Context> y_for_range(dev_ctx, dy_bst.numel());
-    paddle::operators::TrilTriuCompute<T> tril_triu_functor(
+    phi::funcs::TrilTriuCompute<T> tril_triu_functor(
         dy_bst.data<T>(), 0, !upper, H, W, dy_bst_upper.data<T>());
     y_for_range(tril_triu_functor);
 
diff --git a/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h b/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h
index 9b1e4b1d3a6..044adb0230c 100644
--- a/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/triangular_solve_grad_kernel_impl.h
@@ -21,12 +21,11 @@
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/common_shape.h"
 #include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
 #include "paddle/phi/kernels/funcs/matrix_reduce.h"
+#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
 #include "paddle/phi/kernels/triangular_solve_kernel.h"
 
-// See Note [ Why still include the fluid headers? ]
-#include "paddle/fluid/operators/tril_triu_op.h"
-
 namespace phi {
 
 template <typename T, typename Context>
@@ -119,7 +118,7 @@ void TriangularSolveGradKernel(const Context& dev_ctx,
       const auto H = dims[dims.size() - 2];
       const auto W = dims[dims.size() - 1];
       phi::funcs::ForRange<Context> x_for_range(dev_ctx, dx_bst.numel());
-      paddle::operators::TrilTriuCompute<T> tril_triu_functor(
+      phi::funcs::TrilTriuCompute<T> tril_triu_functor(
           dx_bst.data<T>(), unitriangular, !upper, H, W,
           dx_bst_upper.data<T>());
       x_for_range(tril_triu_functor);
diff --git a/paddle/phi/kernels/impl/tril_triu_grad_kernel_impl.h b/paddle/phi/kernels/impl/tril_triu_grad_kernel_impl.h
new file mode 100644
index 00000000000..dcc7224b507
--- /dev/null
+++ b/paddle/phi/kernels/impl/tril_triu_grad_kernel_impl.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/tril_triu_grad_kernel.h"
+
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void TrilTriuGradKernel(const Context& ctx,
+                        const DenseTensor& out_grad,
+                        int diagonal,
+                        bool lower,
+                        DenseTensor* x_grad) {
+  const auto* dout_data = out_grad.data<T>();
+  auto* dx_data = ctx.template Alloc<T>(x_grad);
+
+  const auto& dims = out_grad.dims();
+  const auto H = dims[dims.size() - 2];
+  const auto W = dims[dims.size() - 1];
+
+  phi::funcs::ForRange<Context> for_range(
+      ctx, static_cast<size_t>(out_grad.numel()));
+  phi::funcs::TrilTriuCompute<T> tril_triu_grad_computer(
+      dout_data, diagonal, lower, H, W, dx_data);
+  for_range(tril_triu_grad_computer);
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/impl/tril_triu_kernel_impl.h b/paddle/phi/kernels/impl/tril_triu_kernel_impl.h
new file mode 100644
index 00000000000..959169d87ce
--- /dev/null
+++ b/paddle/phi/kernels/impl/tril_triu_kernel_impl.h
@@ -0,0 +1,43 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/tril_triu_kernel.h"
+
+#include "paddle/phi/kernels/funcs/for_range.h"
+#include "paddle/phi/kernels/funcs/tril_triu_compute.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void TrilTriuKernel(const Context& ctx,
+                    const DenseTensor& x,
+                    int diagonal,
+                    bool lower,
+                    DenseTensor* out) {
+  const auto* x_data = x.data<T>();
+  auto* out_data = ctx.template Alloc<T>(out);
+
+  const auto& dims = x.dims();
+  const auto H = dims[dims.size() - 2];
+  const auto W = dims[dims.size() - 1];
+  phi::funcs::ForRange<Context> for_range(ctx, static_cast<size_t>(x.numel()));
+
+  phi::funcs::TrilTriuCompute<T> tril_triu_computer(
+      x_data, diagonal, lower, H, W, out_data);
+  for_range(tril_triu_computer);
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/multiplex_grad_kernel.h b/paddle/phi/kernels/multiplex_grad_kernel.h
new file mode 100644
index 00000000000..b32c9dbe100
--- /dev/null
+++ b/paddle/phi/kernels/multiplex_grad_kernel.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MultiplexGradKernel(const Context& ctx,
+                         const DenseTensor& ids,
+                         const DenseTensor& out_grad,
+                         std::vector<DenseTensor*> ins_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/multiplex_kernel.h b/paddle/phi/kernels/multiplex_kernel.h
new file mode 100644
index 00000000000..341c6d5cabb
--- /dev/null
+++ b/paddle/phi/kernels/multiplex_kernel.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void MultiplexKernel(const Context& ctx,
+                     const std::vector<const DenseTensor*>& ins,
+                     const DenseTensor& ids,
+                     DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/qr_kernel.h b/paddle/phi/kernels/qr_kernel.h
new file mode 100644
index 00000000000..9c3dfb16601
--- /dev/null
+++ b/paddle/phi/kernels/qr_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void QrKernel(const Context& ctx,
+              const DenseTensor& x,
+              const std::string& mode,
+              DenseTensor* q,
+              DenseTensor* r);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/tril_triu_grad_kernel.h b/paddle/phi/kernels/tril_triu_grad_kernel.h
new file mode 100644
index 00000000000..10faf5c48d5
--- /dev/null
+++ b/paddle/phi/kernels/tril_triu_grad_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void TrilTriuGradKernel(const Context& ctx,
+                        const DenseTensor& out_grad,
+                        int diagonal,
+                        bool lower,
+                        DenseTensor* x_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/tril_triu_kernel.h b/paddle/phi/kernels/tril_triu_kernel.h
new file mode 100644
index 00000000000..4daa84e25c3
--- /dev/null
+++ b/paddle/phi/kernels/tril_triu_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void TrilTriuKernel(const Context& ctx,
+                    const DenseTensor& x,
+                    int diagonal,
+                    bool lower,
+                    DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/multiplex_sig.cc b/paddle/phi/ops/compat/multiplex_sig.cc
new file mode 100644
index 00000000000..9dab4655d17
--- /dev/null
+++ b/paddle/phi/ops/compat/multiplex_sig.cc
@@ -0,0 +1,32 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature MultiplexOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("multiplex", {"X", "Ids"}, {}, {"Out"});
+}
+
+KernelSignature MultiplexGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature(
+      "multiplex_grad", {"Ids", GradVarName("Out")}, {}, {GradVarName("X")});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(multiplex, phi::MultiplexOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(multiplex_grad, phi::MultiplexGradOpArgumentMapping);
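The compat files in this patch only describe how fluid operator input, attribute, and output names line up, in order, with the parameters of the corresponding phi kernel. A hedged, self-contained illustration of that correspondence (the structs below are stand-ins invented for this sketch; the real KernelSignature lives under paddle/phi/core/compat):

// How an argument-mapping entry lines up with a phi kernel's parameter list.
#include <iostream>
#include <string>
#include <vector>

struct KernelSignature {
  std::string name;
  std::vector<std::string> inputs;   // fluid op input names, in phi order
  std::vector<std::string> attrs;    // fluid op attribute names
  std::vector<std::string> outputs;  // fluid op output names
};

int main() {
  // Mirrors tril_triu_sig.cc below: TrilTriuKernel(ctx, x, diagonal, lower, out).
  KernelSignature sig{"tril_triu", {"X"}, {"diagonal", "lower"}, {"Out"}};
  std::cout << sig.name << "(";
  for (const auto& in : sig.inputs) std::cout << in << ", ";
  for (const auto& attr : sig.attrs) std::cout << attr << ", ";
  std::cout << "-> ";
  for (const auto& out : sig.outputs) std::cout << out;
  std::cout << ")\n";  // tril_triu(X, diagonal, lower, -> Out)
}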
diff --git a/paddle/phi/ops/compat/qr_sig.cc b/paddle/phi/ops/compat/qr_sig.cc
new file mode 100644
index 00000000000..dd424d590ee
--- /dev/null
+++ b/paddle/phi/ops/compat/qr_sig.cc
@@ -0,0 +1,25 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature QrOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("qr", {"X"}, {"mode"}, {"Q", "R"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(qr, phi::QrOpArgumentMapping);
diff --git a/paddle/phi/ops/compat/tril_triu_sig.cc b/paddle/phi/ops/compat/tril_triu_sig.cc
new file mode 100644
index 00000000000..4f79f8650de
--- /dev/null
+++ b/paddle/phi/ops/compat/tril_triu_sig.cc
@@ -0,0 +1,34 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature TrilTriuOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("tril_triu", {"X"}, {"diagonal", "lower"}, {"Out"});
+}
+
+KernelSignature TrilTriuGradOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("tril_triu_grad",
+                         {GradVarName("Out")},
+                         {"diagonal", "lower"},
+                         {GradVarName("X")});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(tril_triu, phi::TrilTriuOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(tril_triu_grad, phi::TrilTriuGradOpArgumentMapping);
-- 
GitLab