/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/fluid/platform/transform.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/common_shape.h" #include "paddle/phi/kernels/funcs/elementwise_utils.h" #include "paddle/phi/kernels/funcs/math_function.h" #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) #include "paddle/fluid/platform/function_traits.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/kernels/funcs/aligned_vector.h" #include "paddle/phi/kernels/primitive/kernel_primitives.h" #define HOSTDEVICE __host__ __device__ namespace kps = phi::kps; #endif namespace phi { enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 }; /* Packing scalar type T(float, int etc.) into Array type for supporting multiple-output feature in elementwise system.*/ template using ConditionalT = typename std::conditional_t>; namespace funcs { using DDim = phi::DDim; template class RowwiseTransformIterator; template class MidWiseTransformIterator; // NOTE(dzhwinter): ptrdiff_t in iterator is deperecated in c++17 template class RowwiseTransformIterator : public std::iterator { public: RowwiseTransformIterator(const T *ptr, int n) : ptr_(ptr), i_(0), n_(n) {} RowwiseTransformIterator &operator++() { ++i_; if (UNLIKELY(i_ == n_)) { i_ = 0; } return *this; } RowwiseTransformIterator &operator+(int n) { while (n-- > 0) { ++i_; if (UNLIKELY(i_ == n_)) { i_ = 0; } } return *this; } bool operator==(const RowwiseTransformIterator &rhs) const { return (ptr_ + i_) == &(*rhs); } bool operator!=(const RowwiseTransformIterator &rhs) const { return (ptr_ + i_) != &(*rhs); } const T &operator*() { return ptr_[i_]; } private: const T *ptr_; int i_; int64_t n_; }; template class MidWiseTransformIterator : public std::iterator { public: MidWiseTransformIterator(const T *ptr, int n, int post) : ptr_(ptr), i_(0), j_(0), n_(n), post_(post) {} MidWiseTransformIterator &operator++() { ++j_; if (UNLIKELY(j_ == post_)) { ++i_; j_ = 0; if (UNLIKELY(i_ == n_)) { i_ = 0; } } return *this; } MidWiseTransformIterator &operator+(int n) { while (n-- > 0) { ++j_; if (UNLIKELY(j_ == post_)) { ++i_; j_ = 0; if (UNLIKELY(i_ == n_)) { i_ = 0; } } } return *this; } bool operator==(const MidWiseTransformIterator &rhs) const { return (ptr_ + i_) == &(*rhs); } bool operator!=(const MidWiseTransformIterator &rhs) const { return (ptr_ + i_) != &(*rhs); } const T &operator*() { return ptr_[i_]; } private: const T *ptr_; int64_t i_; int64_t j_; int64_t n_; int64_t post_; }; #if defined(__NVCC__) || defined(__HIPCC__) template class RowwiseTransformIterator : public thrust::iterator_adaptor, const T *> { public: typedef thrust::iterator_adaptor, const T *> super_t; HOSTDEVICE RowwiseTransformIterator(const T *x, int n) : super_t(x), begin_(x), n_(n) {} friend class thrust::iterator_core_access; private: unsigned int n_; const T *begin_; HOSTDEVICE typename super_t::reference dereference() const { return *(begin_ + (this->base() - begin_) % n_); } }; template class MidWiseTransformIterator : public thrust::iterator_adaptor, const T *> { public: typedef thrust::iterator_adaptor, const T *> super_t; HOSTDEVICE MidWiseTransformIterator(const T *x, int n, int post) : super_t(x), begin_(x), n_(n), post_(post) {} friend class thrust::iterator_core_access; private: unsigned int post_; unsigned int n_; const T *begin_; HOSTDEVICE typename super_t::reference dereference() const { return *(begin_ + (((this->base() - begin_) / post_) % n_)); } }; #endif template class TransformFunctor { public: TransformFunctor(const DenseTensor &x, const DenseTensor &y, DenseTensor *z, const DeviceContext &ctx, Functor func, const bool is_xsize_larger = true) : x_(x.data()), y_(y.data()), z_(ctx.template Alloc(z)), nx_(x.numel()), ctx_(ctx), func_(func), is_xsize_larger_(is_xsize_larger) { if (is_xsize_larger_ == false) { nx_ = y.numel(); } } inline void Run() const { paddle::platform::Transform trans; trans(ctx_, x_, x_ + nx_, y_, z_, func_); } inline void RunRowWise(int n, int pre) const { paddle::platform::Transform trans; if (is_xsize_larger_) { trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator(y_, n), z_, func_); } else { trans(ctx_, y_, y_ + nx_, RowwiseTransformIterator(x_, n), z_, func_); } } inline void RunMidWise(int n, int pre, int post) const { paddle::platform::Transform trans; if (is_xsize_larger_) { trans(ctx_, x_, x_ + nx_, MidWiseTransformIterator(y_, n, post), z_, func_); } else { trans(ctx_, y_, y_ + nx_, MidWiseTransformIterator(x_, n, post), z_, func_); } } private: const T *x_; const T *y_; OutType *z_; int64_t nx_; const DeviceContext &ctx_; Functor func_; bool is_xsize_larger_; }; template void CommonForwardBroadcastCPU(const DenseTensor &x, const DenseTensor &y, DenseTensor *z, int *x_dims_array, int *y_dims_array, int *out_dims_array, int max_dim, const CPUContext &ctx, Functor func, const bool is_xsize_larger = true) { std::vector index_array(max_dim, 0); const T *x_data = x.data(); const T *y_data = y.data(); PADDLE_ENFORCE_NOT_NULL( x_data, errors::InvalidArgument("The input X should not be empty.")); PADDLE_ENFORCE_NOT_NULL( y_data, errors::InvalidArgument("The input Y should not be empty.")); OutType *out_data = ctx.Alloc(z); const int out_size = std::accumulate( out_dims_array, out_dims_array + max_dim, 1, std::multiplies()); int x_index, y_index; for (int out_index = 0; out_index < out_size; ++out_index) { x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data()); y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data()); if (is_xsize_larger) { out_data[out_index] = func(x_data[x_index], y_data[y_index]); } else { out_data[out_index] = func(y_data[y_index], x_data[x_index]); } UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data()); } } template void CommonElementwiseBroadcastForward(const CPUContext &dev_ctx, const DenseTensor &x, const DenseTensor &y, DenseTensor *z, const DDim &x_dims, const DDim &y_dims, Functor func, int axis, const bool is_xsize_larger = true) { int max_dim = (std::max)(x_dims.size(), y_dims.size()); axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); PADDLE_ENFORCE_GE( axis, 0, phi::errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); PADDLE_ENFORCE_LE(axis, max_dim, phi::errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", max_dim, axis)); std::vector x_dims_array(max_dim); std::vector y_dims_array(max_dim); std::vector out_dims_array(max_dim); GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(), y_dims_array.data(), out_dims_array.data(), max_dim, axis); CommonForwardBroadcastCPU(x, y, z, x_dims_array.data(), y_dims_array.data(), out_dims_array.data(), max_dim, dev_ctx, func, is_xsize_larger); } // It is a common CPU implementation to compute binary calculation with the // support of broadcast. Note: // 1. CPU implementation cannot support the case when x needs broadcast, thus // this function need to be called with XxxFunctor and XxxInverseFunctor, // like AddFunctor and InverseAddFunctor. // 2. The corresponding GPU implementation supports all the broadcast cases, // thus there is no need to define and call with XxxInverseFunctor. // TODO(liuyiqun): optimize the CPU implementation to support all broadcast // cases and avoid the need of XxxInverseFunctor. template void ElementwiseCompute(const CPUContext &dev_ctx, const DenseTensor &x, const DenseTensor &y, int axis, Functor func, DenseTensor *z) { dev_ctx.Alloc(z); auto x_dims = x.dims(); auto y_dims = y.dims(); bool is_xsize_larger = true; int max_dim = x_dims.size(); if (x_dims.size() < y_dims.size()) { is_xsize_larger = false; max_dim = y_dims.size(); } TransformFunctor functor( x, y, z, dev_ctx, func, is_xsize_larger); if (x_dims == y_dims) { functor.Run(); return; } axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); PADDLE_ENFORCE_GE( axis, 0, errors::InvalidArgument( "Axis should be great than or equal to 0, but received axis is %d.", axis)); PADDLE_ENFORCE_LE(axis, max_dim, errors::InvalidArgument( "Axis should be less than %d, but received axis is %d.", max_dim, axis)); int pre, n, post, is_run_common_broadcast, axis_trim = 0; if (is_xsize_larger) { auto y_dims_trimed = TrimTrailingSingularDims(y_dims); axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis; GetMidDims(x_dims, y_dims_trimed, axis_trim, &pre, &n, &post, &is_run_common_broadcast); } else { auto x_dims_trimed = TrimTrailingSingularDims(x_dims); axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis; GetMidDims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post, &is_run_common_broadcast); } // special case for common implementation. // case 1: x=[2,3,1,5], y=[2,1,4,1] // case 2: x=[2,3,4], y=[1,1,4] if (is_run_common_broadcast == 1) { CommonElementwiseBroadcastForward( dev_ctx, x, y, z, x_dims, y_dims, func, axis, is_xsize_larger); return; } if (post == 1) { functor.RunRowWise(n, pre); return; } else { functor.RunMidWise(n, pre, post); return; } } // for broadcast backwards static inline std::vector GetReduceDim(const DDim &in, const DDim &out, int axis) { axis = (axis == -1 ? std::abs(static_cast(out.size() - in.size())) : axis); std::vector dims; for (int i = 0; i < axis; ++i) { dims.push_back(i); } for (int i = 0; i < in.size(); ++i) { if (out[i + axis] != in[i]) { dims.push_back(i + axis); } } for (int i = axis + in.size(); i < out.size(); ++i) { dims.push_back(i); } return dims; } template static inline void GetDoubleGradSafeTensor(const DeviceContext &dev_ctx, const DenseTensor &x, const DenseTensor *ddx, DenseTensor *ddx_safe) { if (ddx) { *ddx_safe = *ddx; } else { auto meta = phi::DenseTensorMeta(x.dtype(), x.dims(), x.layout()); *ddx_safe = phi::Empty(dev_ctx, std::move(meta)); ddx_safe->mutable_data(dev_ctx.GetPlace()); SetConstant set_zero; set_zero(dev_ctx, ddx_safe, static_cast(0)); } } inline void ElementwiseGradPreProcess(const DenseTensor &dout, DenseTensor *dx) { if (dx != nullptr) { dx->set_lod(dout.lod()); } } #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) // static unroller template