/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/transform.h" #include "paddle/pten/backends/all_context.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/kernels/empty_kernel.h" #if defined(__NVCC__) || defined(__HIPCC__) #include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/function_traits.h" #include "paddle/pten/kernels/primitive/kernel_primitives.h" namespace kps = pten::kps; #endif namespace pten { enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 }; /* Packing scalar type T(float, int etc.) into Array type for supporting multiple-output feature in elementwise system.*/ template using ConditionalT = typename std::conditional_t>; namespace funcs { using DDim = pten::framework::DDim; template struct ElemwiseGradNoBroadcast { const T *x_; const T *y_; const Tout *out_; const Tout *dout_; HOSTDEVICE void operator()(size_t i) { if (dx_ != nullptr) { dx_[i] = dx_op_(x_[i], y_[i], out_[i], dout_[i]); } if (dy_ != nullptr) { dy_[i] = dy_op_(x_[i], y_[i], out_[i], dout_[i]); } } DX_OP dx_op_; DY_OP dy_op_; T *dx_; T *dy_; }; template class RowwiseTransformIterator; template class MidWiseTransformIterator; // NOTE(dzhwinter): ptrdiff_t in iterator is deperecated in c++17 template class RowwiseTransformIterator : public std::iterator { public: RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {} RowwiseTransformIterator &operator++() { ++i_; if (UNLIKELY(i_ == n_)) { i_ = 0; } return *this; } RowwiseTransformIterator &operator+(int n) { while (n-- > 0) { ++i_; if (UNLIKELY(i_ == n_)) { i_ = 0; } } return *this; } bool operator==(const RowwiseTransformIterator &rhs) const { return (ptr_ + i_) == &(*rhs); } bool operator!=(const RowwiseTransformIterator &rhs) const { return (ptr_ + i_) != &(*rhs); } const T &operator*() { return ptr_[i_]; } private: const T *ptr_; int i_; int64_t n_; }; template class MidWiseTransformIterator : public std::iterator { public: MidWiseTransformIterator(const T *ptr, int n, int post) : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} MidWiseTransformIterator &operator++() { ++j_; if (UNLIKELY(j_ == post_)) { ++i_; j_ = 0; if (UNLIKELY(i_ == n_)) { i_ = 0; } } return *this; } MidWiseTransformIterator &operator+(int n) { while (n-- > 0) { ++j_; if (UNLIKELY(j_ == post_)) { ++i_; j_ = 0; if (UNLIKELY(i_ == n_)) { i_ = 0; } } } return *this; } bool operator==(const MidWiseTransformIterator &rhs) const { return (ptr_ + i_) == &(*rhs); } bool operator!=(const MidWiseTransformIterator &rhs) const { return (ptr_ + i_) != &(*rhs); } const T &operator*() { return ptr_[i_]; } private: const T *ptr_; int64_t i_; int64_t j_; int64_t n_; int64_t post_; }; #if defined(__NVCC__) || defined(__HIPCC__) template class RowwiseTransformIterator : public thrust::iterator_adaptor, const T *> { public: typedef thrust::iterator_adaptor, const T *> super_t; HOSTDEVICE RowwiseTransformIterator(const T *x, int n) : super_t(x), begin_(x), n_(n) {} friend class thrust::iterator_core_access; private: unsigned int n_; const T *begin_; HOSTDEVICE typename super_t::reference dereference() const { return *(begin_ + (this->base() - begin_) % n_); } }; template class MidWiseTransformIterator : public thrust::iterator_adaptor, const T *> { public: typedef thrust::iterator_adaptor, const T *> super_t; HOSTDEVICE MidWiseTransformIterator(const T *x, int n, int post) : super_t(x), begin_(x), n_(n), post_(post) {} friend class thrust::iterator_core_access; private: unsigned int post_; unsigned int n_; const T *begin_; HOSTDEVICE typename super_t::reference dereference() const { return *(begin_ + (((this->base() - begin_) / post_) % n_)); } }; #endif template class TransformFunctor { public: TransformFunctor(const DenseTensor &x, const DenseTensor &y, DenseTensor *z, const DeviceContext &ctx, Functor func, const bool is_xsize_larger = true) : x_(x.data()), y_(y.data()), z_(z->mutable_data(ctx.GetPlace())), nx_(x.numel()), ctx_(ctx), func_(func), is_xsize_larger_(is_xsize_larger) { if (is_xsize_larger_ == false) { nx_ = y.numel(); } } inline void Run() const { paddle::platform::Transform trans; trans(ctx_, x_, x_ + nx_, y_, z_, func_); } inline void RunRowWise(int n, int pre) const { paddle::platform::Transform trans; if (is_xsize_larger_) { trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator(y_, n), z_, func_); } else { trans(ctx_, y_, y_ + nx_, RowwiseTransformIterator(x_, n), z_, func_); } } inline void RunMidWise(int n, int pre, int post) const { paddle::platform::Transform trans; if (is_xsize_larger_) { trans(ctx_, x_, x_ + nx_, MidWiseTransformIterator(y_, n, post), z_, func_); } else { trans(ctx_, y_, y_ + nx_, MidWiseTransformIterator(x_, n, post), z_, func_); } } private: const T *x_; const T *y_; OutType *z_; int64_t nx_; const DeviceContext &ctx_; Functor func_; bool is_xsize_larger_; }; inline DDim trim_trailing_singular_dims(const DDim &dims) { // Remove trailing dimensions of size 1 for y auto actual_dims_size = dims.size(); for (; actual_dims_size != 0; --actual_dims_size) { if (dims[actual_dims_size - 1] != 1) break; } if (actual_dims_size == dims.size()) return dims; std::vector trim_dims; trim_dims.resize(actual_dims_size); for (int i = 0; i < actual_dims_size; ++i) { trim_dims[i] = dims[i]; } if (trim_dims.size() == 0) { return DDim(pten::framework::make_dim()); } DDim actual_dims = pten::framework::make_ddim(trim_dims); return actual_dims; } /* * Out = X ⊙ Y * If Y's shape does not match X' shape, they will be reshaped. * For example: * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1 * pre=2, n=3*4, post=5 * x.shape(2, 12, 5) * y.shape(1, 12, 1).broadcast(2, 12, 5) * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5) * pre=2*3, n=4*5, post=1 * x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1) * * New parameter: *is_run_common_broadcast* is a flag to record whether to run * common broadcast code. */ inline void get_mid_dims(const DDim &x_dims, const DDim &y_dims, const int axis, int *pre, int *n, int *post, int *is_run_common_broadcast) { *pre = 1; *n = 1; *post = 1; *is_run_common_broadcast = 0; for (int i = 0; i < axis; ++i) { (*pre) *= x_dims[i]; } for (int i = 0; i < y_dims.size(); ++i) { if (x_dims[i + axis] != y_dims[i]) { PADDLE_ENFORCE_EQ(y_dims[i] == 1 || x_dims[i + axis] == 1, true, paddle::platform::errors::InvalidArgument( "Broadcast dimension mismatch. Operands " "could not be broadcast together with the shape of " "X = [%s] and the shape of Y = [%s]. Received [%d] " "in X is not equal to [%d] in Y.", x_dims, y_dims, x_dims[i + axis], y_dims[i])); *is_run_common_broadcast = 1; return; } (*n) *= y_dims[i]; } for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) { (*post) *= x_dims[i]; } } // for broadcast backwards static inline std::vector GetReduceDim(const paddle::framework::DDim &in, const paddle::framework::DDim &out, int axis) { axis = (axis == -1 ? std::abs(static_cast(out.size() - in.size())) : axis); std::vector dims; for (int i = 0; i < axis; ++i) { dims.push_back(i); } for (int i = 0; i < in.size(); ++i) { if (out[i + axis] != in[i]) { dims.push_back(i + axis); } } for (int i = axis + in.size(); i < out.size(); ++i) { dims.push_back(i); } return dims; } template static inline void GetDoubleGradSafeTensor(const DeviceContext &dev_ctx, const DenseTensor &x, const DenseTensor *ddx, DenseTensor *ddx_safe) { if (ddx) { *ddx_safe = *ddx; } else { auto meta = pten::DenseTensorMeta(x.dtype(), x.dims(), x.layout()); *ddx_safe = pten::Empty(dev_ctx, std::move(meta)); ddx_safe->mutable_data(dev_ctx.GetPlace()); paddle::operators::math::SetConstant set_zero; set_zero(dev_ctx, ddx_safe, static_cast(0)); } } template void ElemwiseGradComputeNoBroadcast(const DeviceContext &dev_ctx, const DDim &x_dim, const DDim &y_dim, const DenseTensor &x, const DenseTensor &y, const DenseTensor &out, const DenseTensor &dout, int axis, DenseTensor *dx, DenseTensor *dy, DX_OP dx_op, DY_OP dy_op) { size_t N = static_cast(pten::framework::product(x_dim)); paddle::platform::ForRange for_range(dev_ctx, N); for_range(ElemwiseGradNoBroadcast{ x.data(), y.data(), out.data(), dout.data(), dx_op, dy_op, dx == nullptr ? nullptr : dx->mutable_data(dev_ctx.GetPlace()), dy == nullptr ? nullptr : dy->mutable_data(dev_ctx.GetPlace())}); } inline void ElementwiseGradPreProcess(const DenseTensor &dout, DenseTensor *dx) { if (dx != nullptr) { dx->set_lod(dout.lod()); } } #if defined(__NVCC__) || defined(__HIPCC__) template int GetVectorizedSizeForTensors(const std::vector &ins, const std::vector &outs) { int vec_size = 4; for (auto iter = ins.begin(); iter != ins.end(); ++iter) { vec_size = std::min( vec_size, paddle::platform::GetVectorizedSize((*iter)->data())); } for (auto iter = outs.begin(); iter != outs.end(); ++iter) { vec_size = std::min( vec_size, paddle::platform::GetVectorizedSize((*iter)->data())); } return vec_size; } template struct ElementwisePrimitiveCaller { __device__ inline void operator()(Functor func, InT (*args)[VecSize], OutT *result); }; template struct ElementwisePrimitiveCaller { __device__ inline void operator()(Functor func, InT (*args)[VecSize], OutT *result) { kps::ElementwiseAny( result, args, func); } }; template struct ElementwisePrimitiveCaller { __device__ inline void operator()(Functor func, InT (*args)[VecSize], OutT *result) { kps::ElementwiseUnary( result, args[0], func); } }; template struct ElementwisePrimitiveCaller { __device__ inline void operator()(Functor func, InT (*args)[VecSize], OutT *result) { kps::ElementwiseBinary( result, args[0], args[1], func); } }; template struct ElementwisePrimitiveCaller { __device__ inline void operator()(Functor func, InT (*args)[VecSize], OutT *result) { kps::ElementwiseTernary( result, args[0], args[1], args[2], func); } }; template struct ElementwiseWriteDataCaller { __device__ __forceinline__ void operator()( pten::framework::Array<_ptr_ OutT *, NumOuts> outs, ConditionalT src[VecSize], int block_offset, int num) { OutT dst[NumOuts][VecSize]; #pragma unroll for (int i = 0; i < VecSize; ++i) { #pragma unroll for (int j = 0; j < NumOuts; ++j) { dst[j][i] = (src[i])[j]; } } #pragma unroll for (int i = 0; i < NumOuts; ++i) { kps::WriteData( outs[i] + block_offset, dst[i], num); } } }; template struct ElementwiseWriteDataCaller { __device__ __forceinline__ void operator()( pten::framework::Array<_ptr_ OutT *, 1> outs, OutT src[VecSize], int block_offset, int num) { kps::WriteData( outs[0] + block_offset, src, num); } }; template __device__ void VectorizedElementwiseKernelImpl( const pten::framework::Array &in, pten::framework::Array<_ptr_ OutT *, NumOuts> outs, int num, int data_offset, Functor func) { InT args[Arity][VecSize]; ConditionalT result[VecSize]; #pragma unroll for (int i = 0; i < Arity; i++) { kps::Init(args[i], static_cast(1.0f)); kps::ReadData( args[i], in[i] + data_offset, num); } constexpr bool kCallElementwiseAny = paddle::platform::FunctionTraits::has_pointer_args; ElementwisePrimitiveCaller, VecSize, Functor, Arity, kCallElementwiseAny>()(func, args, result); ElementwiseWriteDataCaller()( outs, result, data_offset, num); } template __global__ void VectorizedElementwiseKernel( pten::framework::Array ins, pten::framework::Array<_ptr_ OutT *, NumOuts> outs, int size, int main_offset, Functor func) { int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize; for (; data_offset < main_offset; data_offset += stride) { VectorizedElementwiseKernelImpl( ins, outs, VecSize * BLOCK_NUM_X, data_offset, func); } int num = size - data_offset; if (num > 0) { VectorizedElementwiseKernelImpl(ins, outs, num, data_offset, func); } } template void ElementwiseCudaKernel(const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func) { auto numel = ins[0]->numel(); pten::framework::Array ins_data; pten::framework::Array<_ptr_ OutT *, NumOuts> outs_data; for (int i = 0; i < Arity; ++i) { ins_data[i] = ins[i]->data(); } for (int i = 0; i < NumOuts; ++i) { outs_data[i] = (*outs)[i]->mutable_data(ctx.GetPlace()); } #ifdef PADDLE_WITH_XPU2 int block_size = 64; int grid_size = 8; auto stream = ctx.x_context()->xpu_stream; int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size; VectorizedElementwiseKernel<<>>( ins_data, outs_data, numel, main_offset, func); #else auto gpu_config = GetGpuLaunchConfig1D(ctx, numel, VecSize); int main_offset = (numel / (VecSize * gpu_config.GetBlockSize())) * VecSize * gpu_config.GetBlockSize(); auto stream = ctx.stream(); VectorizedElementwiseKernel<<< gpu_config.block_per_grid, gpu_config.thread_per_block, 0, stream>>>(ins_data, outs_data, numel, main_offset, func); #endif } template void LaunchSameDimsElementwiseCudaKernel( const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func) { using Traits = paddle::platform::FunctionTraits; const int kArity = Traits::has_pointer_args ? static_cast(ET) : Traits::arity; PADDLE_ENFORCE_EQ(ins.size(), kArity, paddle::platform::errors::InvalidArgument( "The number of inputs is expected to be equal to the " "arity of functor. But recieved: the number of inputs " "is %d, the arity of functor is %d.", ins.size(), kArity)); PADDLE_ENFORCE_EQ(outs->size(), NumOuts, paddle::platform::errors::InvalidArgument( "Number of outputs shall equal to number of functions, " "but number of outputs is %d, of functions is %d.", outs->size(), NumOuts)); if (NumOuts > 1) { for (int i = 1; i < NumOuts; ++i) { PADDLE_ENFORCE_EQ( (*outs)[i]->dims(), (*outs)[0]->dims(), paddle::platform::errors::InvalidArgument( "The shape of each output tensor shall be identical yet, " "but %dth output tensor`s shape is not.", i)); } } // calculate the max vec_size for all ins and outs int vec_size = GetVectorizedSizeForTensors(ins, *outs); switch (vec_size) { case 4: ElementwiseCudaKernel( ctx, ins, outs, func); break; case 2: ElementwiseCudaKernel( ctx, ins, outs, func); break; case 1: ElementwiseCudaKernel( ctx, ins, outs, func); break; default: { PADDLE_THROW(paddle::platform::errors::Unimplemented( "Unsupported vectorized size: %d !", vec_size)); break; } } } #endif } // namespace funcs } // namespace pten