// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/kernels/impl/kron_kernel_impl.h"

namespace phi {

template <typename T>
struct KronGradElemFunctor {
  KronGradElemFunctor(const T *dout,
                      const T *A,
                      const T *B,
                      T *dout_a,
                      T *dout_b,
                      const int64_t *stride_dout,
                      const int64_t *stride_a,
                      const int64_t *stride_b,
                      const int64_t *shape_b,
                      const int64_t numel_a,
                      const int64_t numel_b,
                      const int ndims)
      : dout_(dout),
        A_(A),
        B_(B),
        dout_a_(dout_a),
        dout_b_(dout_b),
        stride_dout_(stride_dout),
        stride_a_(stride_a),
        stride_b_(stride_b),
        shape_b_(shape_b),
        numel_a_(numel_a),
        numel_b_(numel_b),
        ndims_(ndims) {}

  HOSTDEVICE void operator()(int64_t idx) {
    // Decompose the flat index of dout into flat indices into A and B.
    int64_t index = idx;
    int64_t index_a = 0;
    int64_t index_b = 0;
    for (int i = 0; i < ndims_; i++) {
      auto pos_i = index / stride_dout_[i];
      index = index % stride_dout_[i];
      auto pos_ai = pos_i / shape_b_[i];
      auto pos_bi = pos_i % shape_b_[i];
      index_a += stride_a_[i] * pos_ai;
      index_b += stride_b_[i] * pos_bi;
    }

    if (dout_a_) {
      size_t index_out_a = index_a * numel_b_ + index_b;
      dout_a_[index_out_a] = dout_[idx] * B_[index_b];
    }
    if (dout_b_) {
      size_t index_out_b = index_b * numel_a_ + index_a;
      dout_b_[index_out_b] = dout_[idx] * A_[index_a];
    }
  }

 private:
  const T *dout_;
  const T *A_;
  const T *B_;
  T *dout_a_;
  T *dout_b_;
  const int64_t *stride_dout_;
  const int64_t *stride_a_;
  const int64_t *stride_b_;
  const int64_t *shape_b_;
  const int64_t numel_a_;
  const int64_t numel_b_;
  const int ndims_;
};

template <typename T>
struct KronGradElemFunctor<dtype::complex<T>> {
  KronGradElemFunctor(const dtype::complex<T> *dout,
                      const dtype::complex<T> *A,
                      const dtype::complex<T> *B,
                      dtype::complex<T> *dout_a,
                      dtype::complex<T> *dout_b,
                      const int64_t *stride_dout,
                      const int64_t *stride_a,
                      const int64_t *stride_b,
                      const int64_t *shape_b,
                      const int64_t numel_a,
                      const int64_t numel_b,
                      const int ndims)
      : dout_(dout),
        A_(A),
        B_(B),
        dout_a_(dout_a),
        dout_b_(dout_b),
        stride_dout_(stride_dout),
        stride_a_(stride_a),
        stride_b_(stride_b),
        shape_b_(shape_b),
        numel_a_(numel_a),
        numel_b_(numel_b),
        ndims_(ndims) {}

  HOSTDEVICE void operator()(int64_t idx) {
    int64_t index = idx;
    int64_t index_a = 0;
    int64_t index_b = 0;
    for (int i = 0; i < ndims_; i++) {
      auto pos_i = index / stride_dout_[i];
      index = index % stride_dout_[i];
      auto pos_ai = pos_i / shape_b_[i];
      auto pos_bi = pos_i % shape_b_[i];
      index_a += stride_a_[i] * pos_ai;
      index_b += stride_b_[i] * pos_bi;
    }

    // For complex types the gradient multiplies by the conjugate of the
    // other operand.
    if (dout_a_) {
      size_t index_out_a = index_a * numel_b_ + index_b;
      dout_a_[index_out_a] =
          dout_[idx] * dtype::complex<T>(B_[index_b].real, -B_[index_b].imag);
    }
    if (dout_b_) {
      size_t index_out_b = index_b * numel_a_ + index_a;
      dout_b_[index_out_b] =
          dout_[idx] * dtype::complex<T>(A_[index_a].real, -A_[index_a].imag);
    }
  }

 private:
  const dtype::complex<T> *dout_;
  const dtype::complex<T> *A_;
  const dtype::complex<T> *B_;
  dtype::complex<T> *dout_a_;
  dtype::complex<T> *dout_b_;
  const int64_t *stride_dout_;
  const int64_t *stride_a_;
  const int64_t *stride_b_;
  const int64_t *shape_b_;
  const int64_t numel_a_;
  const int64_t numel_b_;
  const int ndims_;
};
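// How the element functors above recover the gradients: every dimension i of
// dout has extent shape_a[i] * shape_b[i], so the per-dimension position
// pos_i splits into pos_ai = pos_i / shape_b[i] (position in A/X) and
// pos_bi = pos_i % shape_b[i] (position in B/Y). Accumulating these against
// the row-major strides of A and B yields the flat indices index_a and
// index_b. Since d out / d A[index_a] = B[index_b] and
// d out / d B[index_b] = A[index_a] (conjugated for complex types), each dout
// element contributes one product to dA and one to dB; the products are
// staged in dout_a / dout_b and summed afterwards by KronGradOpFunctor.
//
// Worked 1-D example (illustrative only): A of length 2, B of length 3,
// out = kron(A, B) of length 6 with out[3 * ia + ib] = A[ia] * B[ib], so
//   dA[0] = dout[0]*B[0] + dout[1]*B[1] + dout[2]*B[2]
//   dA[1] = dout[3]*B[0] + dout[4]*B[1] + dout[5]*B[2]
//   dB[0] = dout[0]*A[0] + dout[3]*A[1], and so on.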
template <typename Context, typename T>
struct KronGradOpFunctor {
  void operator()(const Context &dev_ctx,
                  const DenseTensor &dout,
                  const DenseTensor &x,
                  const DenseTensor &y,
                  DenseTensor *dx,
                  DenseTensor *dy) {
    int ndims = dout.dims().size();
    int64_t numel = dout.numel();
    int64_t numel_x = x.numel();
    int64_t numel_y = y.numel();

    const phi::DDim &dim_x = x.dims();
    const phi::DDim &dim_y = y.dims();
    const phi::DDim &dim_dout = dout.dims();
    const phi::DDim stride_x =
        dim_x.size() == 0 ? phi::DDim(dim_x) : phi::stride(dim_x);
    const phi::DDim stride_y =
        dim_y.size() == 0 ? phi::DDim(dim_y) : phi::stride(dim_y);
    const phi::DDim stride_dout =
        dim_dout.size() == 0 ? phi::DDim(dim_dout) : phi::stride(dim_dout);

    const int64_t *p_stride_x = nullptr;
    const int64_t *p_stride_y = nullptr;
    const int64_t *p_stride_dout = nullptr;
    const int64_t *p_shape_y = nullptr;
#if defined(__NVCC__) || defined(__HIPCC__)
    // On GPU the stride/shape metadata must live in device memory.
    thrust::device_vector<int64_t> d_stride_x(ndims);
    thrust::device_vector<int64_t> d_stride_y(ndims);
    thrust::device_vector<int64_t> d_stride_dout(ndims);
    thrust::device_vector<int64_t> d_shape_y(ndims);
    thrust::copy(stride_x.Get(), stride_x.Get() + ndims, d_stride_x.begin());
    thrust::copy(stride_y.Get(), stride_y.Get() + ndims, d_stride_y.begin());
    thrust::copy(
        stride_dout.Get(), stride_dout.Get() + ndims, d_stride_dout.begin());
    thrust::copy(dim_y.Get(), dim_y.Get() + ndims, d_shape_y.begin());

    p_stride_x = thrust::raw_pointer_cast(d_stride_x.data());
    p_stride_y = thrust::raw_pointer_cast(d_stride_y.data());
    p_stride_dout = thrust::raw_pointer_cast(d_stride_dout.data());
    p_shape_y = thrust::raw_pointer_cast(d_shape_y.data());
#else
    p_stride_x = stride_x.Get();
    p_stride_y = stride_y.Get();
    p_stride_dout = stride_dout.Get();
    p_shape_y = dim_y.Get();
#endif
    // dout_x: dout * kron(ones(X), Y) rearranged in shape (numel_x, numel_y)
    // dout_y: dout * kron(X, ones(Y)) rearranged in shape (numel_y, numel_x)
    DenseTensor dout_x;
    T *p_dout_x = nullptr;
    if (dx) {
      dout_x.Resize({numel_x, numel_y});
      dev_ctx.template Alloc<T>(&dout_x);
      p_dout_x = dout_x.data<T>();
    }
    DenseTensor dout_y;
    T *p_dout_y = nullptr;
    if (dy) {
      dout_y.Resize({numel_y, numel_x});
      dev_ctx.template Alloc<T>(&dout_y);
      p_dout_y = dout_y.data<T>();
    }

    funcs::ForRange<Context> for_range(dev_ctx, numel);
    KronGradElemFunctor<T> func(dout.data<T>(),
                                x.data<T>(),
                                y.data<T>(),
                                p_dout_x,
                                p_dout_y,
                                p_stride_dout,
                                p_stride_x,
                                p_stride_y,
                                p_shape_y,
                                numel_x,
                                numel_y,
                                ndims);
    for_range(func);

// reduce_sum along axis 1
#if defined(__NVCC__) || defined(__HIPCC__)
    auto stream = dev_ctx.stream();  // it is a cuda device_context
    if (dx) {
      funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
          dev_ctx, dout_x, dx, kps::IdentityFunctor<T>(), {1});
    }
    if (dy) {
      funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
          dev_ctx, dout_y, dy, kps::IdentityFunctor<T>(), {1});
    }
#else
    auto *place = dev_ctx.eigen_device();
    Eigen::array<int, 1> reduce_dim = {1};
    if (dx) {
      auto eigen_dout_x = EigenMatrix<T>::Reshape(dout_x, 1);
      auto eigen_vec_dx = EigenVector<T>::Flatten(*dx);
      eigen_vec_dx.device(*place) = eigen_dout_x.sum(reduce_dim);
    }
    if (dy) {
      auto eigen_dout_y = EigenMatrix<T>::Reshape(dout_y, 1);
      auto eigen_vec_dy = EigenVector<T>::Flatten(*dy);
      eigen_vec_dy.device(*place) = eigen_dout_y.sum(reduce_dim);
    }
#endif
  }
};
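// KronGradKernel (below) is the entry point: it allocates the requested
// gradient tensors, promotes x, y, and the gradients to the rank of out_grad
// with UnsqueezeTo so that the per-dimension decomposition above sees a
// common ndims, and then delegates the actual computation to
// KronGradOpFunctor.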
template <typename T, typename Context>
void KronGradKernel(const Context &ctx,
                    const DenseTensor &x,
                    const DenseTensor &y,
                    const DenseTensor &out_grad,
                    DenseTensor *x_grad,
                    DenseTensor *y_grad) {
  if (x_grad) {
    ctx.template Alloc<T>(x_grad);
  }
  if (y_grad) {
    ctx.template Alloc<T>(y_grad);
  }

  int ndims = out_grad.dims().size();
  DenseTensor xx = UnsqueezeTo(x, ndims);
  DenseTensor yy = UnsqueezeTo(y, ndims);

  DenseTensor *pdxx = nullptr;
  DenseTensor *pdyy = nullptr;
  DenseTensor dxx;
  DenseTensor dyy;
  if (x_grad) {
    dxx = UnsqueezeTo(*x_grad, ndims);
    pdxx = &dxx;
  }
  if (y_grad) {
    dyy = UnsqueezeTo(*y_grad, ndims);
    pdyy = &dyy;
  }

  KronGradOpFunctor<Context, T> func;
  func(ctx, out_grad, xx, yy, pdxx, pdyy);
}

}  // namespace phi
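// Usage sketch (assumption, for illustration; not part of this header): impl
// headers like this one are instantiated by per-backend registration files,
// roughly along the lines of
//
//   PD_REGISTER_KERNEL(kron_grad, CPU, ALL_LAYOUT, phi::KronGradKernel,
//                      float, double,
//                      phi::dtype::complex<float>,
//                      phi::dtype::complex<double>) {}
//
// The concrete dtype list and backend-specific registrations live alongside
// the backend kernels, not in this header.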