From f181d47f250bd801111419bdac6b6abb806a208b Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Tue, 15 Mar 2022 08:43:29 +0800
Subject: [PATCH] [Phi]Move kron kernel to phi (#40427)

* first commit

* fix

* fix

* fix compile error

* fix

* fix complex

* fix

* fix

* fix npu

* fix

* modify according to comments

* fix
---
 paddle/fluid/operators/gather_op_npu.cc       |   1 -
 paddle/fluid/operators/kron_op.cc             |  27 +-
 paddle/fluid/operators/kron_op.cu             |  42 --
 paddle/fluid/operators/kron_op.h              | 415 ------------------
 paddle/fluid/operators/scatter_op_npu.cc      |   2 +-
 paddle/phi/kernels/cpu/kron_grad_kernel.cc    |  31 ++
 paddle/phi/kernels/cpu/kron_kernel.cc         |  31 ++
 paddle/phi/kernels/gpu/kron_grad_kernel.cu    |  31 ++
 paddle/phi/kernels/gpu/kron_kernel.cu         |  31 ++
 .../phi/kernels/impl/kron_grad_kernel_impl.h  | 295 +++++++++++++
 paddle/phi/kernels/impl/kron_kernel_impl.h    | 167 +++++++
 paddle/phi/kernels/kron_grad_kernel.h         |  29 ++
 paddle/phi/kernels/kron_kernel.h              |  27 ++
 paddle/phi/ops/compat/kron_sig.cc             |  28 ++
 14 files changed, 672 insertions(+), 485 deletions(-)
 delete mode 100644 paddle/fluid/operators/kron_op.cu
 delete mode 100644 paddle/fluid/operators/kron_op.h
 create mode 100644 paddle/phi/kernels/cpu/kron_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/kron_kernel.cc
 create mode 100644 paddle/phi/kernels/gpu/kron_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/kron_kernel.cu
 create mode 100644 paddle/phi/kernels/impl/kron_grad_kernel_impl.h
 create mode 100644 paddle/phi/kernels/impl/kron_kernel_impl.h
 create mode 100644 paddle/phi/kernels/kron_grad_kernel.h
 create mode 100644 paddle/phi/kernels/kron_kernel.h
 create mode 100644 paddle/phi/ops/compat/kron_sig.cc

diff --git a/paddle/fluid/operators/gather_op_npu.cc b/paddle/fluid/operators/gather_op_npu.cc
index a83abb24522..21093f585b5 100644
--- a/paddle/fluid/operators/gather_op_npu.cc
+++ b/paddle/fluid/operators/gather_op_npu.cc
@@ -17,7 +17,6 @@ limitations under the License. */
 #include <memory>
 #include <string>
 #include "paddle/fluid/framework/tensor_util.h"
-#include "paddle/fluid/operators/kron_op.h"
 #include "paddle/fluid/platform/device/npu/npu_info.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
 
diff --git a/paddle/fluid/operators/kron_op.cc b/paddle/fluid/operators/kron_op.cc
index 58d51ab1c72..68d0c7978b4 100644
--- a/paddle/fluid/operators/kron_op.cc
+++ b/paddle/fluid/operators/kron_op.cc
@@ -17,9 +17,7 @@ limitations under the License. */
 #include <memory>
 #include <string>
 
-#include "paddle/fluid/operators/kron_op.h"
-#include "paddle/fluid/platform/complex.h"
-#include "paddle/fluid/platform/float16.h"
+#include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
@@ -178,27 +176,4 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(kron, ops::KronOp, ops::KronOpMaker,
                   ops::KronGradOpMaker<paddle::framework::OpDesc>,
                   ops::KronGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(
-    kron, ops::KronKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::KronKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::KronKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::KronKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::KronKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::float16>,
-    ops::KronKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::complex<float>>,
-    ops::KronKernel<paddle::platform::CPUDeviceContext,
-                    paddle::platform::complex<double>>);
-
 REGISTER_OPERATOR(kron_grad, ops::KronGradOp);
-REGISTER_OP_CPU_KERNEL(
-    kron_grad, ops::KronGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::KronGradKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::KronGradKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::KronGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
-    ops::KronGradKernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::float16>,
-    ops::KronGradKernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::complex<float>>,
-    ops::KronGradKernel<paddle::platform::CPUDeviceContext,
-                        paddle::platform::complex<double>>);
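Note on the op itself, for reviewers new to it: kron computes the Kronecker product, in which every element of X scales a full copy of Y, so an (M, N) by (P, Q) input pair yields an (M*P, N*Q) output. The kernels in this patch generalize that to N dimensions; as a reference point, a minimal standalone 2-D version (illustrative only -- names, shapes, and values are not taken from the patch):

    // Dense 2-D Kronecker product: out[i*P + k][j*Q + l] = a[i][j] * b[k][l].
    #include <cstdio>

    int main() {
      const int M = 2, N = 2, P = 2, Q = 2;
      double a[M][N] = {{1, 2}, {3, 4}};
      double b[P][Q] = {{0, 5}, {6, 7}};
      double out[M * P][N * Q];
      for (int i = 0; i < M; ++i)
        for (int j = 0; j < N; ++j)
          for (int k = 0; k < P; ++k)
            for (int l = 0; l < Q; ++l)
              out[i * P + k][j * Q + l] = a[i][j] * b[k][l];
      std::printf("%g\n", out[1][3]);  // a[0][1] * b[1][1] = 2 * 7 = 14
      return 0;
    }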
diff --git a/paddle/fluid/operators/kron_op.cu b/paddle/fluid/operators/kron_op.cu
deleted file mode 100644
index e5124e65007..00000000000
--- a/paddle/fluid/operators/kron_op.cu
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/kron_op.h"
-#include "paddle/fluid/platform/complex.h"
-#include "paddle/fluid/platform/float16.h"
-
-namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(
-    kron, ops::KronKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::KronKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::KronKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::KronKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::KronKernel<paddle::platform::CUDADeviceContext,
-                    paddle::platform::float16>,
-    ops::KronKernel<paddle::platform::CUDADeviceContext,
-                    paddle::platform::complex<float>>,
-    ops::KronKernel<paddle::platform::CUDADeviceContext,
-                    paddle::platform::complex<double>>);
-
-REGISTER_OP_CUDA_KERNEL(
-    kron_grad,
-    ops::KronGradKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::KronGradKernel<paddle::platform::CUDADeviceContext, double>,
-    ops::KronGradKernel<paddle::platform::CUDADeviceContext, int>,
-    ops::KronGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
-    ops::KronGradKernel<paddle::platform::CUDADeviceContext,
-                        paddle::platform::float16>,
-    ops::KronGradKernel<paddle::platform::CUDADeviceContext,
-                        paddle::platform::complex<float>>,
-    ops::KronGradKernel<paddle::platform::CUDADeviceContext,
-                        paddle::platform::complex<double>>);
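The per-dtype registrations deleted above are not lost: they reappear as single PD_REGISTER_KERNEL calls in the new paddle/phi/kernels/{cpu,gpu} files later in this patch. Schematically (comment-only sketch with namespaces abbreviated; not part of the patch):

    // Fluid style: one kernel template instantiation per (device, dtype) pair:
    //   REGISTER_OP_CUDA_KERNEL(kron,
    //       ops::KronKernel<CUDADeviceContext, float>,
    //       ops::KronKernel<CUDADeviceContext, double>, ...);
    //
    // Phi style: one macro call per backend, naming the kernel function
    // template once and listing the dtypes; backend and layout are explicit:
    //   PD_REGISTER_KERNEL(kron, GPU, ALL_LAYOUT, phi::KronKernel,
    //                      float, double, ...) {}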
diff --git a/paddle/fluid/operators/kron_op.h b/paddle/fluid/operators/kron_op.h
deleted file mode 100644
index 274b47c03a4..00000000000
--- a/paddle/fluid/operators/kron_op.h
+++ /dev/null
@@ -1,415 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-#include <algorithm>
-#include <vector>
-#include "paddle/fluid/framework/eigen.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/platform/for_range.h"
-#if defined(__NVCC__) || defined(__HIPCC__)
-#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
-#include "thrust/device_vector.h"
-#endif
-
-namespace paddle {
-namespace operators {
-
-// Process an element in the output, used with a parallel-for
-template <typename T>
-struct KronElemFunctor {
-  KronElemFunctor(const T* a, const T* b, T* out, const int64_t* shape_b,
-                  const int64_t* stride_a, const int64_t* stride_b,
-                  const int64_t* stride_out, int ndims)
-      : a_(a),
-        b_(b),
-        out_(out),
-        shape_b_(shape_b),
-        stride_a_(stride_a),
-        stride_b_(stride_b),
-        stride_out_(stride_out),
-        ndims_(ndims) {}
-
-  HOSTDEVICE void operator()(int64_t idx) const {
-    // it computes one element in the output
-    int64_t index = idx;
-    int64_t index_a = 0;
-    int64_t index_b = 0;
-    for (int i = 0; i < ndims_; i++) {
-      auto pos_i = index / stride_out_[i];
-      index = index % stride_out_[i];
-      auto pos_ai = pos_i / shape_b_[i];
-      auto pos_bi = pos_i % shape_b_[i];
-      index_a += stride_a_[i] * pos_ai;
-      index_b += stride_b_[i] * pos_bi;
-    }
-    out_[idx] = a_[index_a] * b_[index_b];
-  }
-
- private:
-  const T* a_;
-  const T* b_;
-  T* out_;
-  const int64_t* shape_b_;
-  const int64_t* stride_a_;
-  const int64_t* stride_b_;
-  const int64_t* stride_out_;
-  const int ndims_;
-};
-
-template <typename DeviceContext, typename T>
-struct KronOpFunctor {
-  void operator()(const DeviceContext& dev_ctx, const framework::Tensor& x,
-                  const framework::Tensor& y, framework::Tensor* out) {
-    int ndims = out->dims().size();
-    int64_t numel = out->numel();
-
-    const framework::DDim& dim_x = x.dims();
-    const framework::DDim& dim_y = y.dims();
-    const framework::DDim& dim_out = out->dims();
-    const framework::DDim stride_x = phi::stride(dim_x);
-    const framework::DDim stride_y = phi::stride(dim_y);
-    const framework::DDim stride_out = phi::stride(dim_out);
-
-    const int64_t *p_stride_x = nullptr, *p_stride_y = nullptr,
-                  *p_stride_out = nullptr, *p_shape_y = nullptr;
-#if defined(__NVCC__) || defined(__HIPCC__)
-    thrust::device_vector<int64_t> d_stride_x(ndims);
-    thrust::device_vector<int64_t> d_stride_y(ndims);
-    thrust::device_vector<int64_t> d_stride_out(ndims);
-    thrust::device_vector<int64_t> d_shape_y(ndims);
-    thrust::copy(stride_x.Get(), stride_x.Get() + ndims, d_stride_x.begin());
-    thrust::copy(stride_y.Get(), stride_y.Get() + ndims, d_stride_y.begin());
-    thrust::copy(stride_out.Get(), stride_out.Get() + ndims,
-                 d_stride_out.begin());
-    thrust::copy(dim_y.Get(), dim_y.Get() + ndims, d_shape_y.begin());
-
-    p_stride_x = thrust::raw_pointer_cast(d_stride_x.data());
-    p_stride_y = thrust::raw_pointer_cast(d_stride_y.data());
-    p_stride_out = thrust::raw_pointer_cast(d_stride_out.data());
-    p_shape_y = thrust::raw_pointer_cast(d_shape_y.data());
-#else
-    p_stride_x = stride_x.Get();
-    p_stride_y = stride_y.Get();
-    p_stride_out = stride_out.Get();
-    p_shape_y = dim_y.Get();
-#endif
-
-    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
-    KronElemFunctor<T> functor(x.data<T>(), y.data<T>(), out->data<T>(),
-                               p_shape_y, p_stride_x, p_stride_y, p_stride_out,
-                               ndims);
-    for_range(functor);
-  }
-};
-
-template <typename T>
-struct KronGradElemFunctor {
-  KronGradElemFunctor(const T* dout, const T* A, const T* B, T* dout_a,
-                      T* dout_b, const int64_t* stride_dout,
-                      const int64_t* stride_a, const int64_t* stride_b,
-                      const int64_t* shape_b, const int64_t numel_a,
-                      const int64_t numel_b, const int ndims)
-      : dout_(dout),
-        A_(A),
-        B_(B),
-        dout_a_(dout_a),
-        dout_b_(dout_b),
-        stride_dout_(stride_dout),
-        stride_a_(stride_a),
-        stride_b_(stride_b),
-        shape_b_(shape_b),
-        numel_a_(numel_a),
-        numel_b_(numel_b),
-        ndims_(ndims) {}
-
-  HOSTDEVICE void operator()(int64_t idx) {
-    int64_t index = idx;
-    int64_t index_a = 0;
-    int64_t index_b = 0;
-    for (int i = 0; i < ndims_; i++) {
-      auto pos_i = index / stride_dout_[i];
-      index = index % stride_dout_[i];
-      auto pos_ai = pos_i / shape_b_[i];
-      auto pos_bi = pos_i % shape_b_[i];
-      index_a += stride_a_[i] * pos_ai;
-      index_b += stride_b_[i] * pos_bi;
-    }
-
-    if (dout_a_) {
-      size_t index_out_a = index_a * numel_b_ + index_b;
-      dout_a_[index_out_a] = dout_[idx] * B_[index_b];
-    }
-    if (dout_b_) {
-      size_t index_out_b = index_b * numel_a_ + index_a;
-      dout_b_[index_out_b] = dout_[idx] * A_[index_a];
-    }
-  }
-
- private:
-  const T* dout_;
-  const T* A_;
-  const T* B_;
-  T* dout_a_;
-  T* dout_b_;
-  const int64_t* stride_dout_;
-  const int64_t* stride_a_;
-  const int64_t* stride_b_;
-  const int64_t* shape_b_;
-  const int64_t numel_a_;
-  const int64_t numel_b_;
-  const int ndims_;
-};
-
-template <typename T>
-struct KronGradElemFunctor<platform::complex<T>> {
-  KronGradElemFunctor(const platform::complex<T>* dout,
-                      const platform::complex<T>* A,
-                      const platform::complex<T>* B,
-                      platform::complex<T>* dout_a,
-                      platform::complex<T>* dout_b, const int64_t* stride_dout,
-                      const int64_t* stride_a, const int64_t* stride_b,
-                      const int64_t* shape_b, const int64_t numel_a,
-                      const int64_t numel_b, const int ndims)
-      : dout_(dout),
-        A_(A),
-        B_(B),
-        dout_a_(dout_a),
-        dout_b_(dout_b),
-        stride_dout_(stride_dout),
-        stride_a_(stride_a),
-        stride_b_(stride_b),
-        shape_b_(shape_b),
-        numel_a_(numel_a),
-        numel_b_(numel_b),
-        ndims_(ndims) {}
-
-  HOSTDEVICE void operator()(int64_t idx) {
-    int64_t index = idx;
-    int64_t index_a = 0;
-    int64_t index_b = 0;
-    for (int i = 0; i < ndims_; i++) {
-      auto pos_i = index / stride_dout_[i];
-      index = index % stride_dout_[i];
-      auto pos_ai = pos_i / shape_b_[i];
-      auto pos_bi = pos_i % shape_b_[i];
-      index_a += stride_a_[i] * pos_ai;
-      index_b += stride_b_[i] * pos_bi;
-    }
-
-    if (dout_a_) {
-      size_t index_out_a = index_a * numel_b_ + index_b;
-      dout_a_[index_out_a] =
-          dout_[idx] *
-          platform::complex<T>(B_[index_b].real, -B_[index_b].imag);
-    }
-    if (dout_b_) {
-      size_t index_out_b = index_b * numel_a_ + index_a;
-      dout_b_[index_out_b] =
-          dout_[idx] *
-          platform::complex<T>(A_[index_a].real, -A_[index_a].imag);
-    }
-  }
-
- private:
-  const platform::complex<T>* dout_;
-  const platform::complex<T>* A_;
-  const platform::complex<T>* B_;
-  platform::complex<T>* dout_a_;
-  platform::complex<T>* dout_b_;
-  const int64_t* stride_dout_;
-  const int64_t* stride_a_;
-  const int64_t* stride_b_;
-  const int64_t* shape_b_;
-  const int64_t numel_a_;
-  const int64_t numel_b_;
-  const int ndims_;
-};
-
-template <typename DeviceContext, typename T>
-struct KronGradOpFunctor {
-  void operator()(const DeviceContext& dev_ctx, const framework::Tensor& dout,
-                  const framework::Tensor& x, const framework::Tensor& y,
-                  framework::Tensor* dx, framework::Tensor* dy) {
-    int ndims = dout.dims().size();
-    int64_t numel = dout.numel();
-    int64_t numel_x = x.numel();
-    int64_t numel_y = y.numel();
-
-    const framework::DDim& dim_x = x.dims();
-    const framework::DDim& dim_y = y.dims();
-    const framework::DDim& dim_dout = dout.dims();
-
-    const framework::DDim stride_x = phi::stride(dim_x);
-    const framework::DDim stride_y = phi::stride(dim_y);
-    const framework::DDim stride_dout = phi::stride(dim_dout);
-
-    const int64_t* p_stride_x = nullptr;
-    const int64_t* p_stride_y = nullptr;
-    const int64_t* p_stride_dout = nullptr;
-    const int64_t* p_shape_y = nullptr;
-#if defined(__NVCC__) || defined(__HIPCC__)
-    thrust::device_vector<int64_t> d_stride_x(ndims);
-    thrust::device_vector<int64_t> d_stride_y(ndims);
-    thrust::device_vector<int64_t> d_stride_dout(ndims);
-    thrust::device_vector<int64_t> d_shape_y(ndims);
-    thrust::copy(stride_x.Get(), stride_x.Get() + ndims, d_stride_x.begin());
-    thrust::copy(stride_y.Get(), stride_y.Get() + ndims, d_stride_y.begin());
-    thrust::copy(stride_dout.Get(), stride_dout.Get() + ndims,
-                 d_stride_dout.begin());
-    thrust::copy(dim_y.Get(), dim_y.Get() + ndims, d_shape_y.begin());
-
-    p_stride_x = thrust::raw_pointer_cast(d_stride_x.data());
-    p_stride_y = thrust::raw_pointer_cast(d_stride_y.data());
-    p_stride_dout = thrust::raw_pointer_cast(d_stride_dout.data());
-    p_shape_y = thrust::raw_pointer_cast(d_shape_y.data());
-#else
-    p_stride_x = stride_x.Get();
-    p_stride_y = stride_y.Get();
-    p_stride_dout = stride_dout.Get();
-    p_shape_y = dim_y.Get();
-#endif
-    // dout_x: dout * kron(ones(X), Y) re-arranged in shape (numel_x, numel_y)
-    // dout_y: dout * kron(X, ones(Y)) re-arranged in shape (numel_y, numel_x)
-    framework::Tensor dout_x;
-    T* p_dout_x = nullptr;
-    if (dx) {
-      dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace());
-      p_dout_x = dout_x.data<T>();
-    }
-    framework::Tensor dout_y;
-    T* p_dout_y = nullptr;
-    if (dy) {
-      dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace());
-      p_dout_y = dout_y.data<T>();
-    }
-
-    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
-    KronGradElemFunctor<T> func(dout.data<T>(), x.data<T>(), y.data<T>(),
-                                p_dout_x, p_dout_y, p_stride_dout, p_stride_x,
-                                p_stride_y, p_shape_y, numel_x, numel_y,
-                                ndims);
-    for_range(func);
-
-// reduce_sum along axis 1
-#if defined(__NVCC__) || defined(__HIPCC__)
-    auto stream = dev_ctx.stream();  // it is a cuda device_context
-    if (dx) {
-      TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-          dev_ctx, dout_x, dx, kps::IdentityFunctor<T>(), {1}, stream);
-    }
-    if (dy) {
-      TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-          dev_ctx, dout_y, dy, kps::IdentityFunctor<T>(), {1}, stream);
-    }
-#else
-    auto* place = dev_ctx.eigen_device();
-    Eigen::array<int, 1> reduce_dim = {1};
-    if (dx) {
-      auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
-      auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
-      eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
-    }
-    if (dy) {
-      auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
-      auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
-      eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
-    }
-#endif
-  }
-};
-
-inline framework::Tensor UnsqueezeTo(const framework::Tensor& src, int ndims) {
-  const framework::DDim& shape = src.dims();
-  int rank = shape.size();
-  framework::Tensor res;
-  res.ShareDataWith(src);
-  PADDLE_ENFORCE_LE(
-      rank, ndims,
-      platform::errors::InvalidArgument(
-          "The input Tensor's rank should be less than or equal to ndims. "
-          "Received input Tensor's rank = %d, ndims = %d",
-          rank, ndims));
-  if (rank < ndims) {
-    std::vector<int64_t> new_dim(ndims, 1);
-    for (int i = ndims - rank; i < ndims; i++) {
-      new_dim[i] = shape[i - ndims + rank];
-    }
-    res.Resize(phi::make_ddim(new_dim));
-  }
-  return res;
-}
-
-template <typename DeviceContext, typename T>
-class KronKernel : public framework::OpKernel<T> {
- public:
-  virtual void Compute(const framework::ExecutionContext& ctx) const {
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    auto* x = ctx.Input<framework::Tensor>("X");
-    auto* y = ctx.Input<framework::Tensor>("Y");
-
-    auto* out = ctx.Output<framework::Tensor>("Out");
-    out->mutable_data<T>(ctx.GetPlace());
-
-    int ndims = out->dims().size();
-    framework::Tensor xx = UnsqueezeTo(*x, ndims);
-    framework::Tensor yy = UnsqueezeTo(*y, ndims);
-
-    KronOpFunctor<DeviceContext, T> func;
-    func(dev_ctx, xx, yy, out);
-  }
-};
-
-template <typename DeviceContext, typename T>
-class KronGradKernel : public framework::OpKernel<T> {
- public:
-  virtual void Compute(const framework::ExecutionContext& ctx) const {
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    auto* x = ctx.Input<framework::Tensor>("X");
-    auto* y = ctx.Input<framework::Tensor>("Y");
-    auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-
-    auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-    auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
-    if (dx) {
-      dx->mutable_data<T>(ctx.GetPlace());
-    }
-    if (dy) {
-      dy->mutable_data<T>(ctx.GetPlace());
-    }
-
-    int ndims = dout->dims().size();
-    framework::Tensor xx = UnsqueezeTo(*x, ndims);
-    framework::Tensor yy = UnsqueezeTo(*y, ndims);
-
-    framework::Tensor* pdxx = nullptr;
-    framework::Tensor* pdyy = nullptr;
-    framework::Tensor dxx;
-    framework::Tensor dyy;
-    if (dx) {
-      dxx = UnsqueezeTo(*dx, ndims);
-      pdxx = &dxx;
-    }
-
-    if (dy) {
-      dyy = UnsqueezeTo(*dy, ndims);
-      pdyy = &dyy;
-    }
-
-    KronGradOpFunctor<DeviceContext, T> func;
-    func(dev_ctx, *dout, xx, yy, pdxx, pdyy);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/scatter_op_npu.cc b/paddle/fluid/operators/scatter_op_npu.cc
index 815984ac307..d5ef95269b4 100644
--- a/paddle/fluid/operators/scatter_op_npu.cc
+++ b/paddle/fluid/operators/scatter_op_npu.cc
@@ -16,7 +16,7 @@ limitations under the License. */
 #include <memory>
 #include <string>
 
-#include "paddle/fluid/operators/kron_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
 
 namespace paddle {
diff --git a/paddle/phi/kernels/cpu/kron_grad_kernel.cc b/paddle/phi/kernels/cpu/kron_grad_kernel.cc
new file mode 100644
index 00000000000..01f5e5404b6
--- /dev/null
+++ b/paddle/phi/kernels/cpu/kron_grad_kernel.cc
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/kron_grad_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/kron_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(kron_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::KronGradKernel,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/cpu/kron_kernel.cc b/paddle/phi/kernels/cpu/kron_kernel.cc
new file mode 100644
index 00000000000..aaea509dc76
--- /dev/null
+++ b/paddle/phi/kernels/cpu/kron_kernel.cc
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/kron_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/kron_kernel_impl.h"
+
+PD_REGISTER_KERNEL(kron,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::KronKernel,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/kron_grad_kernel.cu b/paddle/phi/kernels/gpu/kron_grad_kernel.cu
new file mode 100644
index 00000000000..13ef2adaab3
--- /dev/null
+++ b/paddle/phi/kernels/gpu/kron_grad_kernel.cu
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/kron_grad_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/kron_grad_kernel_impl.h"
+
+PD_REGISTER_KERNEL(kron_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::KronGradKernel,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
diff --git a/paddle/phi/kernels/gpu/kron_kernel.cu b/paddle/phi/kernels/gpu/kron_kernel.cu
new file mode 100644
index 00000000000..a2124fd5af7
--- /dev/null
+++ b/paddle/phi/kernels/gpu/kron_kernel.cu
@@ -0,0 +1,31 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/kron_kernel.h"
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/impl/kron_kernel_impl.h"
+
+PD_REGISTER_KERNEL(kron,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::KronKernel,
+                   int,
+                   int64_t,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
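The two impl headers that follow port the computation out of the deleted kron_op.h. The heart of every functor in them is the same stride-based index decomposition: a linear offset into the output is peeled apart dimension by dimension, and each per-dimension position splits (div/mod by shape_b[i]) into a position in A and a position in B. A standalone trace of that arithmetic for A = B = 2x2, out = 4x4 (values assumed for illustration; not from the patch):

    #include <cstdio>

    int main() {
      const int ndims = 2;
      long shape_b[ndims] = {2, 2};
      long stride_a[ndims] = {2, 1};    // row-major strides of A (2x2)
      long stride_b[ndims] = {2, 1};    // row-major strides of B (2x2)
      long stride_out[ndims] = {4, 1};  // row-major strides of out (4x4)
      long idx = 7;                     // linear offset of out[1][3]
      long index = idx, index_a = 0, index_b = 0;
      for (int i = 0; i < ndims; ++i) {  // same loop as KronElemFunctor
        long pos_i = index / stride_out[i];
        index %= stride_out[i];
        index_a += stride_a[i] * (pos_i / shape_b[i]);
        index_b += stride_b[i] * (pos_i % shape_b[i]);
      }
      // out[1][3] = A[0][1] * B[1][1]: linear offsets 1 and 3 respectively.
      std::printf("%ld %ld\n", index_a, index_b);  // prints: 1 3
      return 0;
    }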
diff --git a/paddle/phi/kernels/impl/kron_grad_kernel_impl.h b/paddle/phi/kernels/impl/kron_grad_kernel_impl.h
new file mode 100644
index 00000000000..30297b53eab
--- /dev/null
+++ b/paddle/phi/kernels/impl/kron_grad_kernel_impl.h
@@ -0,0 +1,295 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/kernels/impl/kron_kernel_impl.h"
+
+namespace phi {
+
+template <typename T>
+struct KronGradElemFunctor {
+  KronGradElemFunctor(const T* dout,
+                      const T* A,
+                      const T* B,
+                      T* dout_a,
+                      T* dout_b,
+                      const int64_t* stride_dout,
+                      const int64_t* stride_a,
+                      const int64_t* stride_b,
+                      const int64_t* shape_b,
+                      const int64_t numel_a,
+                      const int64_t numel_b,
+                      const int ndims)
+      : dout_(dout),
+        A_(A),
+        B_(B),
+        dout_a_(dout_a),
+        dout_b_(dout_b),
+        stride_dout_(stride_dout),
+        stride_a_(stride_a),
+        stride_b_(stride_b),
+        shape_b_(shape_b),
+        numel_a_(numel_a),
+        numel_b_(numel_b),
+        ndims_(ndims) {}
+
+  HOSTDEVICE void operator()(int64_t idx) {
+    int64_t index = idx;
+    int64_t index_a = 0;
+    int64_t index_b = 0;
+    for (int i = 0; i < ndims_; i++) {
+      auto pos_i = index / stride_dout_[i];
+      index = index % stride_dout_[i];
+      auto pos_ai = pos_i / shape_b_[i];
+      auto pos_bi = pos_i % shape_b_[i];
+      index_a += stride_a_[i] * pos_ai;
+      index_b += stride_b_[i] * pos_bi;
+    }
+
+    if (dout_a_) {
+      size_t index_out_a = index_a * numel_b_ + index_b;
+      dout_a_[index_out_a] = dout_[idx] * B_[index_b];
+    }
+    if (dout_b_) {
+      size_t index_out_b = index_b * numel_a_ + index_a;
+      dout_b_[index_out_b] = dout_[idx] * A_[index_a];
+    }
+  }
+
+ private:
+  const T* dout_;
+  const T* A_;
+  const T* B_;
+  T* dout_a_;
+  T* dout_b_;
+  const int64_t* stride_dout_;
+  const int64_t* stride_a_;
+  const int64_t* stride_b_;
+  const int64_t* shape_b_;
+  const int64_t numel_a_;
+  const int64_t numel_b_;
+  const int ndims_;
+};
+
+template <typename T>
+struct KronGradElemFunctor<dtype::complex<T>> {
+  KronGradElemFunctor(const dtype::complex<T>* dout,
+                      const dtype::complex<T>* A,
+                      const dtype::complex<T>* B,
+                      dtype::complex<T>* dout_a,
+                      dtype::complex<T>* dout_b,
+                      const int64_t* stride_dout,
+                      const int64_t* stride_a,
+                      const int64_t* stride_b,
+                      const int64_t* shape_b,
+                      const int64_t numel_a,
+                      const int64_t numel_b,
+                      const int ndims)
+      : dout_(dout),
+        A_(A),
+        B_(B),
+        dout_a_(dout_a),
+        dout_b_(dout_b),
+        stride_dout_(stride_dout),
+        stride_a_(stride_a),
+        stride_b_(stride_b),
+        shape_b_(shape_b),
+        numel_a_(numel_a),
+        numel_b_(numel_b),
+        ndims_(ndims) {}
+
+  HOSTDEVICE void operator()(int64_t idx) {
+    int64_t index = idx;
+    int64_t index_a = 0;
+    int64_t index_b = 0;
+    for (int i = 0; i < ndims_; i++) {
+      auto pos_i = index / stride_dout_[i];
+      index = index % stride_dout_[i];
+      auto pos_ai = pos_i / shape_b_[i];
+      auto pos_bi = pos_i % shape_b_[i];
+      index_a += stride_a_[i] * pos_ai;
+      index_b += stride_b_[i] * pos_bi;
+    }
+
+    if (dout_a_) {
+      size_t index_out_a = index_a * numel_b_ + index_b;
+      dout_a_[index_out_a] =
+          dout_[idx] * dtype::complex<T>(B_[index_b].real, -B_[index_b].imag);
+    }
+    if (dout_b_) {
+      size_t index_out_b = index_b * numel_a_ + index_a;
+      dout_b_[index_out_b] =
+          dout_[idx] * dtype::complex<T>(A_[index_a].real, -A_[index_a].imag);
+    }
+  }
+
+ private:
+  const dtype::complex<T>* dout_;
+  const dtype::complex<T>* A_;
+  const dtype::complex<T>* B_;
+  dtype::complex<T>* dout_a_;
+  dtype::complex<T>* dout_b_;
+  const int64_t* stride_dout_;
+  const int64_t* stride_a_;
+  const int64_t* stride_b_;
+  const int64_t* shape_b_;
+  const int64_t numel_a_;
+  const int64_t numel_b_;
+  const int ndims_;
+};
+
+template <typename Context, typename T>
+struct KronGradOpFunctor {
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& dout,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* dx,
+                  DenseTensor* dy) {
+    int ndims = dout.dims().size();
+    int64_t numel = dout.numel();
+    int64_t numel_x = x.numel();
+    int64_t numel_y = y.numel();
+
+    const phi::DDim& dim_x = x.dims();
+    const phi::DDim& dim_y = y.dims();
+    const phi::DDim& dim_dout = dout.dims();
+
+    const phi::DDim stride_x = phi::stride(dim_x);
+    const phi::DDim stride_y = phi::stride(dim_y);
+    const phi::DDim stride_dout = phi::stride(dim_dout);
+
+    const int64_t* p_stride_x = nullptr;
+    const int64_t* p_stride_y = nullptr;
+    const int64_t* p_stride_dout = nullptr;
+    const int64_t* p_shape_y = nullptr;
+#if defined(__NVCC__) || defined(__HIPCC__)
+    thrust::device_vector<int64_t> d_stride_x(ndims);
+    thrust::device_vector<int64_t> d_stride_y(ndims);
+    thrust::device_vector<int64_t> d_stride_dout(ndims);
+    thrust::device_vector<int64_t> d_shape_y(ndims);
+    thrust::copy(stride_x.Get(), stride_x.Get() + ndims, d_stride_x.begin());
+    thrust::copy(stride_y.Get(), stride_y.Get() + ndims, d_stride_y.begin());
+    thrust::copy(
+        stride_dout.Get(), stride_dout.Get() + ndims, d_stride_dout.begin());
+    thrust::copy(dim_y.Get(), dim_y.Get() + ndims, d_shape_y.begin());
+
+    p_stride_x = thrust::raw_pointer_cast(d_stride_x.data());
+    p_stride_y = thrust::raw_pointer_cast(d_stride_y.data());
+    p_stride_dout = thrust::raw_pointer_cast(d_stride_dout.data());
+    p_shape_y = thrust::raw_pointer_cast(d_shape_y.data());
+#else
+    p_stride_x = stride_x.Get();
+    p_stride_y = stride_y.Get();
+    p_stride_dout = stride_dout.Get();
+    p_shape_y = dim_y.Get();
+#endif
+    // dout_x: dout * kron(ones(X), Y) re-arranged in shape (numel_x, numel_y)
+    // dout_y: dout * kron(X, ones(Y)) re-arranged in shape (numel_y, numel_x)
+    DenseTensor dout_x;
+    T* p_dout_x = nullptr;
+    if (dx) {
+      dout_x.Resize({numel_x, numel_y});
+      dev_ctx.template Alloc<T>(&dout_x);
+      p_dout_x = dout_x.data<T>();
+    }
+    DenseTensor dout_y;
+    T* p_dout_y = nullptr;
+    if (dy) {
+      dout_y.Resize({numel_y, numel_x});
+      dev_ctx.template Alloc<T>(&dout_y);
+      p_dout_y = dout_y.data<T>();
+    }
+
+    funcs::ForRange<Context> for_range(dev_ctx, numel);
+    KronGradElemFunctor<T> func(dout.data<T>(),
+                                x.data<T>(),
+                                y.data<T>(),
+                                p_dout_x,
+                                p_dout_y,
+                                p_stride_dout,
+                                p_stride_x,
+                                p_stride_y,
+                                p_shape_y,
+                                numel_x,
+                                numel_y,
+                                ndims);
+    for_range(func);
+
+// reduce_sum along axis 1
+#if defined(__NVCC__) || defined(__HIPCC__)
+    auto stream = dev_ctx.stream();  // it is a cuda device_context
+    if (dx) {
+      funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+          dev_ctx, dout_x, dx, kps::IdentityFunctor<T>(), {1});
+    }
+    if (dy) {
+      funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+          dev_ctx, dout_y, dy, kps::IdentityFunctor<T>(), {1});
+    }
+#else
+    auto* place = dev_ctx.eigen_device();
+    Eigen::array<int, 1> reduce_dim = {1};
+    if (dx) {
+      auto eigen_dout_x = EigenMatrix<T>::Reshape(dout_x, 1);
+      auto eigen_vec_dx = EigenVector<T>::Flatten(*dx);
+      eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
+    }
+    if (dy) {
+      auto eigen_dout_y = EigenMatrix<T>::Reshape(dout_y, 1);
+      auto eigen_vec_dy = EigenVector<T>::Flatten(*dy);
+      eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
+    }
+#endif
+  }
+};
+
+template <typename T, typename Context>
+void KronGradKernel(const Context& ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& y,
+                    const DenseTensor& out_grad,
+                    DenseTensor* x_grad,
+                    DenseTensor* y_grad) {
+  if (x_grad) {
+    ctx.template Alloc<T>(x_grad);
+  }
+  if (y_grad) {
+    ctx.template Alloc<T>(y_grad);
+  }
+
+  int ndims = out_grad.dims().size();
+  DenseTensor xx = UnsqueezeTo(x, ndims);
+  DenseTensor yy = UnsqueezeTo(y, ndims);
+
+  DenseTensor* pdxx = nullptr;
+  DenseTensor* pdyy = nullptr;
+  DenseTensor dxx;
+  DenseTensor dyy;
+  if (x_grad) {
+    dxx = UnsqueezeTo(*x_grad, ndims);
+    pdxx = &dxx;
+  }
+
+  if (y_grad) {
+    dyy = UnsqueezeTo(*y_grad, ndims);
+    pdyy = &dyy;
+  }
+
+  KronGradOpFunctor<Context, T> func;
+  func(ctx, out_grad, xx, yy, pdxx, pdyy);
+}
+
+}  // namespace phi
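To summarize the gradient scheme implemented above: each out_grad element touches exactly one (index_a, index_b) pair, so the functor scatters dout * B into a dense (numel_a, numel_b) buffer and dX then falls out as a row sum -- the "reduce_sum along axis 1" step (funcs::ReduceKernel on GPU, an Eigen sum on CPU). For complex dtypes the specialization multiplies by the conjugate of the other operand instead, via dtype::complex<T>(v.real, -v.imag). A 1-D sketch of the real-valued path, under assumed shapes (illustrative only):

    #include <cstdio>

    int main() {
      const int NA = 2, NB = 3;
      double b[NB] = {1, 2, 3};
      double dout[NA * NB] = {1, 1, 1, 1, 1, 1};  // upstream gradient
      double buf[NA][NB];                         // the dout_x scratch buffer
      for (int i = 0; i < NA * NB; ++i) {
        int ia = i / NB, ib = i % NB;  // same index split as the functor
        buf[ia][ib] = dout[i] * b[ib];
      }
      double a_grad[NA] = {0, 0};
      for (int ia = 0; ia < NA; ++ia)  // reduce_sum along axis 1
        for (int ib = 0; ib < NB; ++ib) a_grad[ia] += buf[ia][ib];
      std::printf("%g %g\n", a_grad[0], a_grad[1]);  // prints: 6 6
      return 0;
    }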
diff --git a/paddle/phi/kernels/impl/kron_kernel_impl.h b/paddle/phi/kernels/impl/kron_kernel_impl.h
new file mode 100644
index 00000000000..47c76f59df2
--- /dev/null
+++ b/paddle/phi/kernels/impl/kron_kernel_impl.h
@@ -0,0 +1,167 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <vector>
+
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
+#if defined(__NVCC__) || defined(__HIPCC__)
+#include "paddle/phi/kernels/funcs/reduce_function.h"
+#include "thrust/device_vector.h"
+#endif
+
+namespace phi {
+
+inline DenseTensor UnsqueezeTo(const DenseTensor& src, int ndims) {
+  const phi::DDim& shape = src.dims();
+  int rank = shape.size();
+  DenseTensor res;
+  res.ShareDataWith(src);
+  PADDLE_ENFORCE_LE(
+      rank,
+      ndims,
+      errors::InvalidArgument(
+          "The input Tensor's rank should be less than or equal to ndims. "
+          "Received input Tensor's rank = %d, ndims = %d",
+          rank,
+          ndims));
+  if (rank < ndims) {
+    std::vector<int64_t> new_dim(ndims, 1);
+    for (int i = ndims - rank; i < ndims; i++) {
+      new_dim[i] = shape[i - ndims + rank];
+    }
+    res.Resize(phi::make_ddim(new_dim));
+  }
+  return res;
+}
+
+template <typename T>
+struct KronElemFunctor {
+  KronElemFunctor(const T* a,
+                  const T* b,
+                  T* out,
+                  const int64_t* shape_b,
+                  const int64_t* stride_a,
+                  const int64_t* stride_b,
+                  const int64_t* stride_out,
+                  int ndims)
+      : a_(a),
+        b_(b),
+        out_(out),
+        shape_b_(shape_b),
+        stride_a_(stride_a),
+        stride_b_(stride_b),
+        stride_out_(stride_out),
+        ndims_(ndims) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    // it computes one element in the output
+    int64_t index = idx;
+    int64_t index_a = 0;
+    int64_t index_b = 0;
+    for (int i = 0; i < ndims_; i++) {
+      auto pos_i = index / stride_out_[i];
+      index = index % stride_out_[i];
+      auto pos_ai = pos_i / shape_b_[i];
+      auto pos_bi = pos_i % shape_b_[i];
+      index_a += stride_a_[i] * pos_ai;
+      index_b += stride_b_[i] * pos_bi;
+    }
+    out_[idx] = a_[index_a] * b_[index_b];
+  }
+
+ private:
+  const T* a_;
+  const T* b_;
+  T* out_;
+  const int64_t* shape_b_;
+  const int64_t* stride_a_;
+  const int64_t* stride_b_;
+  const int64_t* stride_out_;
+  const int ndims_;
+};
+
+template <typename Context, typename T>
+struct KronOpFunctor {
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  const DenseTensor& y,
+                  DenseTensor* out) {
+    int ndims = out->dims().size();
+    int64_t numel = out->numel();
+
+    const phi::DDim& dim_x = x.dims();
+    const phi::DDim& dim_y = y.dims();
+    const phi::DDim& dim_out = out->dims();
+    const phi::DDim stride_x = phi::stride(dim_x);
+    const phi::DDim stride_y = phi::stride(dim_y);
+    const phi::DDim stride_out = phi::stride(dim_out);
+
+    const int64_t *p_stride_x = nullptr, *p_stride_y = nullptr,
+                  *p_stride_out = nullptr, *p_shape_y = nullptr;
+#if defined(__NVCC__) || defined(__HIPCC__)
+    thrust::device_vector<int64_t> d_stride_x(ndims);
+    thrust::device_vector<int64_t> d_stride_y(ndims);
+    thrust::device_vector<int64_t> d_stride_out(ndims);
+    thrust::device_vector<int64_t> d_shape_y(ndims);
+    thrust::copy(stride_x.Get(), stride_x.Get() + ndims, d_stride_x.begin());
+    thrust::copy(stride_y.Get(), stride_y.Get() + ndims, d_stride_y.begin());
+    thrust::copy(
+        stride_out.Get(), stride_out.Get() + ndims, d_stride_out.begin());
+    thrust::copy(dim_y.Get(), dim_y.Get() + ndims, d_shape_y.begin());
+
+    p_stride_x = thrust::raw_pointer_cast(d_stride_x.data());
+    p_stride_y = thrust::raw_pointer_cast(d_stride_y.data());
+    p_stride_out = thrust::raw_pointer_cast(d_stride_out.data());
+    p_shape_y = thrust::raw_pointer_cast(d_shape_y.data());
+#else
+    p_stride_x = stride_x.Get();
+    p_stride_y = stride_y.Get();
+    p_stride_out = stride_out.Get();
+    p_shape_y = dim_y.Get();
+#endif
+
+    funcs::ForRange<Context> for_range(dev_ctx, numel);
+    KronElemFunctor<T> functor(x.data<T>(),
+                               y.data<T>(),
+                               out->data<T>(),
+                               p_shape_y,
+                               p_stride_x,
+                               p_stride_y,
+                               p_stride_out,
+                               ndims);
+    for_range(functor);
+  }
+};
+
+template <typename T, typename Context>
+void KronKernel(const Context& ctx,
+                const DenseTensor& x,
+                const DenseTensor& y,
+                DenseTensor* out) {
+  ctx.template Alloc<T>(out);
+
+  int ndims = out->dims().size();
+  DenseTensor xx = UnsqueezeTo(x, ndims);
+  DenseTensor yy = UnsqueezeTo(y, ndims);
+
+  KronOpFunctor<Context, T> func;
+  func(ctx, xx, yy, out);
+}
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/kron_grad_kernel.h b/paddle/phi/kernels/kron_grad_kernel.h
new file mode 100644
index 00000000000..3daa9dcfba9
--- /dev/null
+++ b/paddle/phi/kernels/kron_grad_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void KronGradKernel(const Context& ctx,
+                    const DenseTensor& x,
+                    const DenseTensor& y,
+                    const DenseTensor& out_grad,
+                    DenseTensor* x_grad,
+                    DenseTensor* y_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/kron_kernel.h b/paddle/phi/kernels/kron_kernel.h
new file mode 100644
index 00000000000..4451ac757a9
--- /dev/null
+++ b/paddle/phi/kernels/kron_kernel.h
@@ -0,0 +1,27 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void KronKernel(const Context& ctx,
+                const DenseTensor& x,
+                const DenseTensor& y,
+                DenseTensor* out);
+
+}  // namespace phi
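A brief note on UnsqueezeTo, which both kernels above call before looping: it aligns operands of different rank by left-padding the lower-rank shape with ones (sharing the data, not copying), so one N-D index loop covers both inputs. A shape-only sketch of the padding rule (hypothetical helper name; illustrative only):

    #include <cstdio>
    #include <vector>

    // Mirrors UnsqueezeTo's shape computation: left-pad with ones.
    std::vector<long> UnsqueezeShape(const std::vector<long>& shape, int ndims) {
      std::vector<long> out(ndims, 1);
      int rank = static_cast<int>(shape.size());
      for (int i = ndims - rank; i < ndims; ++i) out[i] = shape[i - ndims + rank];
      return out;
    }

    int main() {
      auto s = UnsqueezeShape({3, 4}, 4);  // view a rank-2 shape as rank-4
      std::printf("%ld %ld %ld %ld\n", s[0], s[1], s[2], s[3]);  // 1 1 3 4
      return 0;
    }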
diff --git a/paddle/phi/ops/compat/kron_sig.cc b/paddle/phi/ops/compat/kron_sig.cc
new file mode 100644
index 00000000000..06b6545f58e
--- /dev/null
+++ b/paddle/phi/ops/compat/kron_sig.cc
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/compat/op_utils.h"
+
+namespace phi {
+
+KernelSignature KronGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("kron_grad",
+                         {"X", "Y", GradVarName("Out")},
+                         {},
+                         {GradVarName("X"), GradVarName("Y")});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(kron_grad, phi::KronGradOpArgumentMapping);
-- 
GitLab