/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/transform.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" #if defined(__NVCC__) || defined(__HIPCC__) #include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/function_traits.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/kernels/primitive/kernel_primitives.h" namespace kps = phi::kps; #endif namespace phi { enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 }; /* Packing scalar type T(float, int etc.) into Array type for supporting multiple-output feature in elementwise system.*/ template using ConditionalT = typename std::conditional_t>; namespace funcs { using DDim = phi::DDim; template struct ElemwiseGradNoBroadcast { const T *x_; const T *y_; const Tout *out_; const Tout *dout_; HOSTDEVICE void operator()(size_t i) { if (dx_ != nullptr) { dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]); } if (dy_ != nullptr) { dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]); } } DX_OP dx_op_; DY_OP dy_op_; T *dx_; T *dy_; }; template class RowwiseTransformIterator; template class MidWiseTransformIterator; // NOTE(dzhwinter): ptrdiff_t in iterator is deperecated in c++17 template class RowwiseTransformIterator : public std::iterator { public: RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {} RowwiseTransformIterator &operator++() { ++i_; if (UNLIKELY(i_ == n_)) { i_ = 0; } return *this; } RowwiseTransformIterator &operator+(int n) { while (n-- > 0) { ++i_; if (UNLIKELY(i_ == n_)) { i_ = 0; } } return *this; } bool operator==(const RowwiseTransformIterator &rhs) const { return (ptr_ + i_) == &(*rhs); } bool operator!=(const RowwiseTransformIterator &rhs) const { return (ptr_ + i_) != &(*rhs); } const T &operator*() { return ptr_[i_]; } private: const T *ptr_; int i_; int64_t n_; }; template class MidWiseTransformIterator : public std::iterator { public: MidWiseTransformIterator(const T *ptr, int n, int post) : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} MidWiseTransformIterator &operator++() { ++j_; if (UNLIKELY(j_ == post_)) { ++i_; j_ = 0; if (UNLIKELY(i_ == n_)) { i_ = 0; } } return *this; } MidWiseTransformIterator &operator+(int n) { while (n-- > 0) { ++j_; if (UNLIKELY(j_ == post_)) { ++i_; j_ = 0; if (UNLIKELY(i_ == n_)) { i_ = 0; } } } return *this; } bool operator==(const MidWiseTransformIterator &rhs) const { return (ptr_ + i_) == &(*rhs); } bool operator!=(const MidWiseTransformIterator &rhs) const { return (ptr_ + i_) != &(*rhs); } const T &operator*() { return ptr_[i_]; } private: const T *ptr_; int64_t i_; int64_t j_; int64_t n_; int64_t post_; }; #if defined(__NVCC__) || defined(__HIPCC__) template class RowwiseTransformIterator : public thrust::iterator_adaptor, const T *> { public: typedef thrust::iterator_adaptor, const T *> super_t; HOSTDEVICE RowwiseTransformIterator(const T *x, int n) : super_t(x), begin_(x), n_(n) {} friend class thrust::iterator_core_access; private: unsigned int n_; const T *begin_; HOSTDEVICE typename super_t::reference dereference() const { return *(begin_ + (this->base() - begin_) % n_); } }; template class MidWiseTransformIterator : public thrust::iterator_adaptor, const T *> { public: typedef thrust::iterator_adaptor, const T *> super_t; HOSTDEVICE MidWiseTransformIterator(const T *x, int n, int post) : super_t(x), begin_(x), n_(n), post_(post) {} friend class thrust::iterator_core_access; private: unsigned int post_; unsigned int n_; const T *begin_; HOSTDEVICE typename super_t::reference dereference() const { return *(begin_ + (((this->base() - begin_) / post_) % n_)); } }; #endif template class TransformFunctor { public: TransformFunctor(const DenseTensor &x, const DenseTensor &y, DenseTensor *z, const DeviceContext &ctx, Functor func, const bool is_xsize_larger = true) : x_(x.data()), y_(y.data()), z_(ctx.template Alloc(z)), nx_(x.numel()), ctx_(ctx), func_(func), is_xsize_larger_(is_xsize_larger) { if (is_xsize_larger_ == false) { nx_ = y.numel(); } } inline void Run() const { paddle::platform::Transform trans; trans(ctx_, x_, x_ + nx_, y_, z_, func_); } inline void RunRowWise(int n, int pre) const { paddle::platform::Transform trans; if (is_xsize_larger_) { trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator(y_, n), z_, func_); } else { trans(ctx_, y_, y_ + nx_, RowwiseTransformIterator(x_, n), z_, func_); } } inline void RunMidWise(int n, int pre, int post) const { paddle::platform::Transform trans; if (is_xsize_larger_) { trans(ctx_, x_, x_ + nx_, MidWiseTransformIterator(y_, n, post), z_, func_); } else { trans(ctx_, y_, y_ + nx_, MidWiseTransformIterator(x_, n, post), z_, func_); } } private: const T *x_; const T *y_; OutType *z_; int64_t nx_; const DeviceContext &ctx_; Functor func_; bool is_xsize_larger_; }; inline DDim trim_trailing_singular_dims(const DDim &dims) { // Remove trailing dimensions of size 1 for y auto actual_dims_size = dims.size(); for (; actual_dims_size != 0; --actual_dims_size) { if (dims[actual_dims_size - 1] != 1) break; } if (actual_dims_size == dims.size()) return dims; std::vector trim_dims; trim_dims.resize(actual_dims_size); for (int i = 0; i < actual_dims_size; ++i) { trim_dims[i] = dims[i]; } if (trim_dims.size() == 0) { return DDim(phi::make_dim()); } DDim actual_dims = phi::make_ddim(trim_dims); return actual_dims; } /* * Out = X ⊙ Y * If Y's shape does not match X' shape, they will be reshaped. * For example: * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 * pre=2, n=3*4, post=5 * x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5) * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5) * pre=2*3, n=4*5, post=1 * x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1) * * New parameter: *is_run_common_broadcast* is a flag to record whether to run * common broadcast code. */ inline void get_mid_dims(const DDim &x_dims, const DDim &y_dims, const int axis, int *pre, int *n, int *post, int *is_run_common_broadcast) { *pre = 1; *n = 1; *post = 1; *is_run_common_broadcast = 0; for (int i = 0; i < axis; ++i) { (*pre) *= x_dims[i]; } for (int i = 0; i < y_dims.size(); ++i) { if (x_dims[i + axis] != y_dims[i]) { PADDLE_ENFORCE_EQ(y_dims[i] == 1 || x_dims[i + axis] == 1, true, phi::errors::InvalidArgument( "Broadcast dimension mismatch. Operands " "could not be broadcast together with the shape of " "X = [%s] and the shape of Y = [%s]. Received [%d] " "in X is not equal to [%d] in Y.", x_dims, y_dims, x_dims[i + axis], y_dims[i])); *is_run_common_broadcast = 1; return; } (*n) *= y_dims[i]; } for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { (*post) *= x_dims[i]; } } // for broadcast backwards static inline std::vector GetReduceDim(const DDim &in, const DDim &out, int axis) { axis = (axis == -1 ? std::abs(static_cast(out.size() - in.size())) : axis); std::vector dims; for (int i = 0; i < axis; ++i) { dims.push_back(i); } for (int i = 0; i < in.size(); ++i) { if (out[i + axis] != in[i]) { dims.push_back(i + axis); } } for (int i = axis + in.size(); i < out.size(); ++i) { dims.push_back(i); } return dims; } template static inline void GetDoubleGradSafeTensor(const DeviceContext &dev_ctx, const DenseTensor &x, const DenseTensor *ddx, DenseTensor *ddx_safe) { if (ddx) { *ddx_safe = *ddx; } else { auto meta = phi::DenseTensorMeta(x.dtype(), x.dims(), x.layout()); *ddx_safe = phi::Empty(dev_ctx, std::move(meta)); ddx_safe->mutable_data(dev_ctx.GetPlace()); phi::funcs::SetConstant set_zero; set_zero(dev_ctx, ddx_safe, static_cast(0)); } } template void ElemwiseGradComputeNoBroadcast(const DeviceContext &dev_ctx, const DDim &x_dim, const DDim &y_dim, const DenseTensor &x, const DenseTensor &y, const DenseTensor &out, const DenseTensor &dout, int axis, DenseTensor *dx, DenseTensor *dy, DX_OP dx_op, DY_OP dy_op) { size_t N = static_cast(phi::product(x_dim)); paddle::platform::ForRange for_range(dev_ctx, N); for_range(ElemwiseGradNoBroadcast{ x.data(), y.data(), out.data(), dout.data(), dx_op, dy_op, dx == nullptr ? nullptr : dev_ctx.template Alloc(dx), dy == nullptr ? nullptr : dev_ctx.template Alloc(dy)}); } inline void ElementwiseGradPreProcess(const DenseTensor &dout, DenseTensor *dx) { if (dx != nullptr) { dx->set_lod(dout.lod()); } } #if defined(__NVCC__) || defined(__HIPCC__) // static unroller template