/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
#if __NVCC__
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#include "thrust/device_vector.h"
#endif

namespace paddle {
namespace operators {

using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;

// Process an element in the output, used with a parallel-for
template <typename T>
struct KronElemFunctor {
  KronElemFunctor(const T* a, const T* b, T* out, const int64_t* shape_b,
                  const int64_t* stride_a, const int64_t* stride_b,
                  const int64_t* stride_out, int ndims)
      : a_(a),
        b_(b),
        out_(out),
        shape_b_(shape_b),
        stride_a_(stride_a),
        stride_b_(stride_b),
        stride_out_(stride_out),
        ndims_(ndims) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    // computes one element in the output: decompose the flat output index
    // into per-dimension coordinates, split each coordinate into the
    // corresponding coordinates of a and b, then re-linearize via strides
    int64_t index = idx;
    int64_t index_a = 0;
    int64_t index_b = 0;
    for (int i = 0; i < ndims_; i++) {
      auto pos_i = index / stride_out_[i];
      index = index % stride_out_[i];
      auto pos_ai = pos_i / shape_b_[i];
      auto pos_bi = pos_i % shape_b_[i];
      index_a += stride_a_[i] * pos_ai;
      index_b += stride_b_[i] * pos_bi;
    }
    out_[idx] = a_[index_a] * b_[index_b];
  }

 private:
  const T* a_;
  const T* b_;
  T* out_;
  const int64_t* shape_b_;
  const int64_t* stride_a_;
  const int64_t* stride_b_;
  const int64_t* stride_out_;
  const int ndims_;
};

template <typename DeviceContext, typename T>
struct KronOpFunctor {
  void operator()(const DeviceContext& dev_ctx, const framework::Tensor& x,
                  const framework::Tensor& y, framework::Tensor* out) {
    int ndims = out->dims().size();
    int64_t numel = out->numel();

    const framework::DDim& dim_x = x.dims();
    const framework::DDim& dim_y = y.dims();
    const framework::DDim& dim_out = out->dims();
    const framework::DDim stride_x = framework::stride(dim_x);
    const framework::DDim stride_y = framework::stride(dim_y);
    const framework::DDim stride_out = framework::stride(dim_out);

    const int64_t *p_stride_x = nullptr, *p_stride_y = nullptr,
                  *p_stride_out = nullptr, *p_shape_y = nullptr;
#if __NVCC__
    // On CUDA, the shape/stride metadata must live in device memory so the
    // functor can dereference it inside the kernel.
    thrust::device_vector<int64_t> d_stride_x(ndims);
    thrust::device_vector<int64_t> d_stride_y(ndims);
    thrust::device_vector<int64_t> d_stride_out(ndims);
    thrust::device_vector<int64_t> d_shape_y(ndims);
    thrust::copy(stride_x.Get(), stride_x.Get() + ndims, d_stride_x.begin());
    thrust::copy(stride_y.Get(), stride_y.Get() + ndims, d_stride_y.begin());
    thrust::copy(stride_out.Get(), stride_out.Get() + ndims,
                 d_stride_out.begin());
    thrust::copy(dim_y.Get(), dim_y.Get() + ndims, d_shape_y.begin());

    p_stride_x = thrust::raw_pointer_cast(d_stride_x.data());
    p_stride_y = thrust::raw_pointer_cast(d_stride_y.data());
    p_stride_out = thrust::raw_pointer_cast(d_stride_out.data());
    p_shape_y = thrust::raw_pointer_cast(d_shape_y.data());
#else
    p_stride_x = stride_x.Get();
    p_stride_y = stride_y.Get();
    p_stride_out = stride_out.Get();
    p_shape_y = dim_y.Get();
#endif

    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
    KronElemFunctor<T> functor(x.data<T>(), y.data<T>(), out->data<T>(),
                               p_shape_y, p_stride_x, p_stride_y, p_stride_out,
                               ndims);
    for_range(functor);
  }
};
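// Worked example of the index mapping used by KronElemFunctor: for the 2-D
// case with X of shape (r_a, c_a) and Y of shape (r_b, c_b), the output has
// shape (r_a * r_b, c_a * c_b) and
//   Out[i][j] = X[i / r_b][j / c_b] * Y[i % r_b][j % c_b]
// e.g. X: (2, 3), Y: (4, 5) => Out: (8, 15) and Out[5][7] = X[1][1] * Y[1][2].
// The functor generalizes this to arbitrary rank: each per-dimension
// coordinate of the output is split by shape_b into a coordinate of X
// (quotient) and a coordinate of Y (remainder), and the strides turn those
// coordinates back into flat offsets.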
template <typename T>
struct KronGradElemFunctor {
  KronGradElemFunctor(const T* dout, const T* A, const T* B, T* dout_a,
                      T* dout_b, const int64_t* stride_dout,
                      const int64_t* stride_a, const int64_t* stride_b,
                      const int64_t* shape_b, const int64_t numel_a,
                      const int64_t numel_b, const int ndims)
      : dout_(dout),
        A_(A),
        B_(B),
        dout_a_(dout_a),
        dout_b_(dout_b),
        stride_dout_(stride_dout),
        stride_a_(stride_a),
        stride_b_(stride_b),
        shape_b_(shape_b),
        numel_a_(numel_a),
        numel_b_(numel_b),
        ndims_(ndims) {}

  HOSTDEVICE void operator()(int64_t idx) {
    int64_t index = idx;
    int64_t index_a = 0;
    int64_t index_b = 0;
    for (int i = 0; i < ndims_; i++) {
      auto pos_i = index / stride_dout_[i];
      index = index % stride_dout_[i];
      auto pos_ai = pos_i / shape_b_[i];
      auto pos_bi = pos_i % shape_b_[i];
      index_a += stride_a_[i] * pos_ai;
      index_b += stride_b_[i] * pos_bi;
    }

    if (dout_a_) {
      size_t index_out_a = index_a * numel_b_ + index_b;
      dout_a_[index_out_a] = dout_[idx] * B_[index_b];
    }
    if (dout_b_) {
      size_t index_out_b = index_b * numel_a_ + index_a;
      dout_b_[index_out_b] = dout_[idx] * A_[index_a];
    }
  }

 private:
  const T* dout_;
  const T* A_;
  const T* B_;
  T* dout_a_;
  T* dout_b_;
  const int64_t* stride_dout_;
  const int64_t* stride_a_;
  const int64_t* stride_b_;
  const int64_t* shape_b_;
  const int64_t numel_a_;
  const int64_t numel_b_;
  const int ndims_;
};

template <>
struct KronGradElemFunctor<complex64> {
  KronGradElemFunctor(const complex64* dout, const complex64* A,
                      const complex64* B, complex64* dout_a, complex64* dout_b,
                      const int64_t* stride_dout, const int64_t* stride_a,
                      const int64_t* stride_b, const int64_t* shape_b,
                      const int64_t numel_a, const int64_t numel_b,
                      const int ndims)
      : dout_(dout),
        A_(A),
        B_(B),
        dout_a_(dout_a),
        dout_b_(dout_b),
        stride_dout_(stride_dout),
        stride_a_(stride_a),
        stride_b_(stride_b),
        shape_b_(shape_b),
        numel_a_(numel_a),
        numel_b_(numel_b),
        ndims_(ndims) {}

  HOSTDEVICE void operator()(int64_t idx) {
    int64_t index = idx;
    int64_t index_a = 0;
    int64_t index_b = 0;
    for (int i = 0; i < ndims_; i++) {
      auto pos_i = index / stride_dout_[i];
      index = index % stride_dout_[i];
      auto pos_ai = pos_i / shape_b_[i];
      auto pos_bi = pos_i % shape_b_[i];
      index_a += stride_a_[i] * pos_ai;
      index_b += stride_b_[i] * pos_bi;
    }

    if (dout_a_) {
      size_t index_out_a = index_a * numel_b_ + index_b;
      dout_a_[index_out_a] =
          dout_[idx] * complex64(B_[index_b].real, -B_[index_b].imag);
    }
    if (dout_b_) {
      size_t index_out_b = index_b * numel_a_ + index_a;
      dout_b_[index_out_b] =
          dout_[idx] * complex64(A_[index_a].real, -A_[index_a].imag);
    }
  }

 private:
  const complex64* dout_;
  const complex64* A_;
  const complex64* B_;
  complex64* dout_a_;
  complex64* dout_b_;
  const int64_t* stride_dout_;
  const int64_t* stride_a_;
  const int64_t* stride_b_;
  const int64_t* shape_b_;
  const int64_t numel_a_;
  const int64_t numel_b_;
  const int ndims_;
};

template <>
struct KronGradElemFunctor<complex128> {
  KronGradElemFunctor(const complex128* dout, const complex128* A,
                      const complex128* B, complex128* dout_a,
                      complex128* dout_b, const int64_t* stride_dout,
                      const int64_t* stride_a, const int64_t* stride_b,
                      const int64_t* shape_b, const int64_t numel_a,
                      const int64_t numel_b, const int ndims)
      : dout_(dout),
        A_(A),
        B_(B),
        dout_a_(dout_a),
        dout_b_(dout_b),
        stride_dout_(stride_dout),
        stride_a_(stride_a),
        stride_b_(stride_b),
        shape_b_(shape_b),
        numel_a_(numel_a),
        numel_b_(numel_b),
        ndims_(ndims) {}

  HOSTDEVICE void operator()(int64_t idx) {
    int64_t index = idx;
    int64_t index_a = 0;
    int64_t index_b = 0;
    for (int i = 0; i < ndims_; i++) {
      auto pos_i = index / stride_dout_[i];
      index = index % stride_dout_[i];
      auto pos_ai = pos_i / shape_b_[i];
      auto pos_bi = pos_i % shape_b_[i];
      index_a += stride_a_[i] * pos_ai;
      index_b += stride_b_[i] * pos_bi;
    }

    if (dout_a_) {
      size_t index_out_a = index_a * numel_b_ + index_b;
      dout_a_[index_out_a] =
          dout_[idx] * complex128(B_[index_b].real, -B_[index_b].imag);
    }
    if (dout_b_) {
      size_t index_out_b = index_b * numel_a_ + index_a;
      dout_b_[index_out_b] =
          dout_[idx] * complex128(A_[index_a].real, -A_[index_a].imag);
    }
  }

 private:
  const complex128* dout_;
  const complex128* A_;
  const complex128* B_;
  complex128* dout_a_;
  complex128* dout_b_;
  const int64_t* stride_dout_;
  const int64_t* stride_a_;
  const int64_t* stride_b_;
  const int64_t* shape_b_;
  const int64_t numel_a_;
  const int64_t numel_b_;
  const int ndims_;
};
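// Note on the complex specializations above: for complex64/complex128 the
// gradient multiplies by the conjugate of the other operand, since
// complexN(v.real, -v.imag) == conj(v). That is, dX accumulates
// dout * conj(Y) and dY accumulates dout * conj(X), following the conjugate
// gradient convention for complex tensors; the generic functor is the
// real-valued case, where conjugation is the identity.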
template <typename T>
struct IdentityFunctor {
  HOSTDEVICE explicit inline IdentityFunctor() {}

  HOSTDEVICE inline T operator()(const T& x) const { return x; }
};

template <typename DeviceContext, typename T>
struct KronGradOpFunctor {
  void operator()(const DeviceContext& dev_ctx, const framework::Tensor& dout,
                  const framework::Tensor& x, const framework::Tensor& y,
                  framework::Tensor* dx, framework::Tensor* dy) {
    int ndims = dout.dims().size();
    int64_t numel = dout.numel();
    int64_t numel_x = x.numel();
    int64_t numel_y = y.numel();

    const framework::DDim& dim_x = x.dims();
    const framework::DDim& dim_y = y.dims();
    const framework::DDim& dim_dout = dout.dims();

    const framework::DDim stride_x = framework::stride(dim_x);
    const framework::DDim stride_y = framework::stride(dim_y);
    const framework::DDim stride_dout = framework::stride(dim_dout);

    const int64_t* p_stride_x = nullptr;
    const int64_t* p_stride_y = nullptr;
    const int64_t* p_stride_dout = nullptr;
    const int64_t* p_shape_y = nullptr;
#if __NVCC__
    thrust::device_vector<int64_t> d_stride_x(ndims);
    thrust::device_vector<int64_t> d_stride_y(ndims);
    thrust::device_vector<int64_t> d_stride_dout(ndims);
    thrust::device_vector<int64_t> d_shape_y(ndims);
    thrust::copy(stride_x.Get(), stride_x.Get() + ndims, d_stride_x.begin());
    thrust::copy(stride_y.Get(), stride_y.Get() + ndims, d_stride_y.begin());
    thrust::copy(stride_dout.Get(), stride_dout.Get() + ndims,
                 d_stride_dout.begin());
    thrust::copy(dim_y.Get(), dim_y.Get() + ndims, d_shape_y.begin());

    p_stride_x = thrust::raw_pointer_cast(d_stride_x.data());
    p_stride_y = thrust::raw_pointer_cast(d_stride_y.data());
    p_stride_dout = thrust::raw_pointer_cast(d_stride_dout.data());
    p_shape_y = thrust::raw_pointer_cast(d_shape_y.data());
#else
    p_stride_x = stride_x.Get();
    p_stride_y = stride_y.Get();
    p_stride_dout = stride_dout.Get();
    p_shape_y = dim_y.Get();
#endif
    // dout_x: dout * kron(ones(X), Y) re-arranged in shape (numel_x, numel_y)
    // dout_y: dout * kron(X, ones(Y)) re-arranged in shape (numel_y, numel_x)
    framework::Tensor dout_x;
    T* p_dout_x = nullptr;
    if (dx) {
      dout_x.mutable_data<T>({numel_x, numel_y}, dev_ctx.GetPlace());
      p_dout_x = dout_x.data<T>();
    }
    framework::Tensor dout_y;
    T* p_dout_y = nullptr;
    if (dy) {
      dout_y.mutable_data<T>({numel_y, numel_x}, dev_ctx.GetPlace());
      p_dout_y = dout_y.data<T>();
    }

    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
    KronGradElemFunctor<T> func(dout.data<T>(), x.data<T>(), y.data<T>(),
                                p_dout_x, p_dout_y, p_stride_dout, p_stride_x,
                                p_stride_y, p_shape_y, numel_x, numel_y,
                                ndims);
    for_range(func);

    // reduce_sum along axis 1
#if __NVCC__
    auto stream = dev_ctx.stream();  // it is a cuda device_context
    if (dx) {
      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
          dout_x, dx, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
          stream);
    }
    if (dy) {
      TensorReduce<T, T, cub::Sum, IdentityFunctor<T>>(
          dout_y, dy, {1}, static_cast<T>(0), cub::Sum(), IdentityFunctor<T>(),
          stream);
    }
#else
    auto* place = dev_ctx.eigen_device();
    Eigen::array<int, 1> reduce_dim = {1};
    if (dx) {
      auto eigen_dout_x = framework::EigenMatrix<T>::Reshape(dout_x, 1);
      auto eigen_vec_dx = framework::EigenVector<T>::Flatten(*dx);
      eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
    }
    if (dy) {
      auto eigen_dout_y = framework::EigenMatrix<T>::Reshape(dout_y, 1);
      auto eigen_vec_dy = framework::EigenVector<T>::Flatten(*dy);
      eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
    }
#endif
  }
};
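// Note on the scatter-then-reduce scheme above: every element of dout
// contributes dout[idx] * conj(Y[index_b]) to dX[index_a] (and symmetrically
// for dY), so the per-element products are first scattered into dout_x of
// shape (numel_x, numel_y), where the row is the flat index into X and the
// column is the flat index into Y, and then summed along axis 1:
//   dX[p] = sum_q dout[loc(p, q)] * conj(Y[q])
// where loc(p, q) is the output position produced by X[p] * Y[q], and conj
// is the identity for real types. Each (p, q) pair maps to exactly one
// output position, so the scatter writes are race-free.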
inline framework::Tensor UnsqueezeTo(const framework::Tensor& src, int ndims) {
  const framework::DDim& shape = src.dims();
  int rank = shape.size();
  framework::Tensor res;
  res.ShareDataWith(src);
  PADDLE_ENFORCE_LE(
      rank, ndims,
      platform::errors::InvalidArgument(
          "The input Tensor's rank should be less than or equal to ndims. "
          "Received input Tensor's rank = %d, ndims = %d.",
          rank, ndims));
  if (rank < ndims) {
    // Prepend size-1 dimensions, e.g. a (3, 4) tensor unsqueezed to ndims = 4
    // is viewed as (1, 1, 3, 4); the data is shared, only the shape changes.
    std::vector<int64_t> new_dim(ndims, 1);
    for (int i = ndims - rank; i < ndims; i++) {
      new_dim[i] = shape[i - ndims + rank];
    }
    res.Resize(framework::make_ddim(new_dim));
  }
  return res;
}

template <typename DeviceContext, typename T>
class KronKernel : public framework::OpKernel<T> {
 public:
  virtual void Compute(const framework::ExecutionContext& ctx) const {
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    auto* x = ctx.Input<framework::Tensor>("X");
    auto* y = ctx.Input<framework::Tensor>("Y");
    auto* out = ctx.Output<framework::Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());

    int ndims = out->dims().size();
    framework::Tensor xx = UnsqueezeTo(*x, ndims);
    framework::Tensor yy = UnsqueezeTo(*y, ndims);

    KronOpFunctor<DeviceContext, T> func;
    func(dev_ctx, xx, yy, out);
  }
};

template <typename DeviceContext, typename T>
class KronGradKernel : public framework::OpKernel<T> {
 public:
  virtual void Compute(const framework::ExecutionContext& ctx) const {
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    auto* x = ctx.Input<framework::Tensor>("X");
    auto* y = ctx.Input<framework::Tensor>("Y");
    auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
    if (dx) {
      dx->mutable_data<T>(ctx.GetPlace());
    }
    if (dy) {
      dy->mutable_data<T>(ctx.GetPlace());
    }

    int ndims = dout->dims().size();
    framework::Tensor xx = UnsqueezeTo(*x, ndims);
    framework::Tensor yy = UnsqueezeTo(*y, ndims);

    framework::Tensor* pdxx = nullptr;
    framework::Tensor* pdyy = nullptr;
    framework::Tensor dxx;
    framework::Tensor dyy;
    if (dx) {
      dxx = UnsqueezeTo(*dx, ndims);
      pdxx = &dxx;
    }
    if (dy) {
      dyy = UnsqueezeTo(*dy, ndims);
      pdyy = &dyy;
    }

    KronGradOpFunctor<DeviceContext, T> func;
    func(dev_ctx, *dout, xx, yy, pdxx, pdyy);
  }
};

}  // namespace operators
}  // namespace paddle
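// Usage note (a sketch, not part of this header): the kernels above are only
// defined here; they are registered in the accompanying kron_op.cc /
// kron_op.cu sources (file names assumed), along the lines of
//
//   REGISTER_OP_CPU_KERNEL(
//       kron, ops::KronKernel<paddle::platform::CPUDeviceContext, float>,
//       ops::KronKernel<paddle::platform::CPUDeviceContext, double>);
//
// The exact data-type list (e.g. float16, int, int64_t, complex64,
// complex128) is defined at the registration site, not in this header.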