/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" #include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/function_traits.h" #include "paddle/pten/core/dense_tensor.h" namespace pten { namespace kps = paddle::operators::kernel_primitives; enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 }; /* Packing scalar type T(float, int etc.) into Array type for supporting multiple-output feature in elementwise system.*/ template using ConditionalT = typename std::conditional_t>; template struct ElementwisePrimitiveCaller { __device__ inline void operator()(Functor func, InT (*args)[VecSize], OutT *result); }; template struct ElementwisePrimitiveCaller { __device__ inline void operator()(Functor func, InT (*args)[VecSize], OutT *result) { kps::ElementwiseAny( result, args, func); } }; template struct ElementwisePrimitiveCaller { __device__ inline void operator()(Functor func, InT (*args)[VecSize], OutT *result) { kps::ElementwiseUnary( result, args[0], func); } }; template struct ElementwisePrimitiveCaller { __device__ inline void operator()(Functor func, InT (*args)[VecSize], OutT *result) { kps::ElementwiseBinary( result, args[0], args[1], func); } }; template struct ElementwisePrimitiveCaller { __device__ inline void operator()(Functor func, InT (*args)[VecSize], OutT *result) { kps::ElementwiseTernary( result, args[0], args[1], args[2], func); } }; template struct ElementwiseWriteDataCaller { __device__ __forceinline__ void operator()( paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, ConditionalT src[VecSize], int block_offset, int num) { OutT dst[NumOuts][VecSize]; #pragma unroll for (int i = 0; i < VecSize; ++i) { #pragma unroll for (int j = 0; j < NumOuts; ++j) { dst[j][i] = (src[i])[j]; } } #pragma unroll for (int i = 0; i < NumOuts; ++i) { kps::WriteData( outs[i] + block_offset, dst[i], num); } } }; template struct ElementwiseWriteDataCaller { __device__ __forceinline__ void operator()( paddle::framework::Array<_ptr_ OutT *, 1> outs, OutT src[VecSize], int block_offset, int num) { kps::WriteData( outs[0] + block_offset, src, num); } }; template __device__ void VectorizedElementwiseKernelImpl( const paddle::framework::Array &in, paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, int num, int data_offset, Functor func) { InT args[Arity][VecSize]; ConditionalT result[VecSize]; #pragma unroll for (int i = 0; i < Arity; i++) { kps::Init(args[i], static_cast(1.0f)); kps::ReadData( args[i], in[i] + data_offset, num); } constexpr bool kCallElementwiseAny = paddle::platform::FunctionTraits::has_pointer_args; ElementwisePrimitiveCaller, VecSize, Functor, Arity, kCallElementwiseAny>()(func, args, result); ElementwiseWriteDataCaller()( outs, result, data_offset, num); } template __global__ void VectorizedElementwiseKernel( paddle::framework::Array ins, paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, int size, int main_offset, Functor func) { int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize; for (; data_offset < main_offset; data_offset += stride) { VectorizedElementwiseKernelImpl( ins, outs, VecSize * BLOCK_NUM_X, data_offset, func); } int num = size - data_offset; if (num > 0) { VectorizedElementwiseKernelImpl(ins, outs, num, data_offset, func); } } template int GetVectorizedSizeForTensors(const std::vector &ins, const std::vector &outs) { int vec_size = 4; for (auto iter = ins.begin(); iter != ins.end(); ++iter) { vec_size = std::min( vec_size, paddle::platform::GetVectorizedSize((*iter)->data())); } for (auto iter = outs.begin(); iter != outs.end(); ++iter) { vec_size = std::min( vec_size, paddle::platform::GetVectorizedSize((*iter)->data())); } return vec_size; } template void ElementwiseCudaKernel(const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func) { auto numel = ins[0]->numel(); paddle::framework::Array ins_data; paddle::framework::Array<_ptr_ OutT *, NumOuts> outs_data; for (int i = 0; i < Arity; ++i) { ins_data[i] = ins[i]->data(); } for (int i = 0; i < NumOuts; ++i) { outs_data[i] = (*outs)[i]->mutable_data(); } #ifdef PADDLE_WITH_XPU2 int block_size = 64; int grid_size = 8; auto stream = ctx.x_context()->xpu_stream; int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size; VectorizedElementwiseKernel<<>>( ins_data, outs_data, numel, main_offset, func); #else auto gpu_config = GetGpuLaunchConfig1D(ctx, numel, VecSize); int main_offset = (numel / (VecSize * gpu_config.GetBlockSize())) * VecSize * gpu_config.GetBlockSize(); auto stream = ctx.stream(); VectorizedElementwiseKernel<<< gpu_config.block_per_grid, gpu_config.thread_per_block, 0, stream>>>(ins_data, outs_data, numel, main_offset, func); #endif } template void LaunchSameDimsElementwiseCudaKernel( const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func) { using Traits = paddle::platform::FunctionTraits; const int kArity = Traits::has_pointer_args ? static_cast(ET) : Traits::arity; PADDLE_ENFORCE_EQ(ins.size(), kArity, paddle::platform::errors::InvalidArgument( "The number of inputs is expected to be equal to the " "arity of functor. But recieved: the number of inputs " "is %d, the arity of functor is %d.", ins.size(), kArity)); PADDLE_ENFORCE_EQ(outs->size(), NumOuts, paddle::platform::errors::InvalidArgument( "Number of outputs shall equal to number of functions, " "but number of outputs is %d, of functions is %d.", outs->size(), NumOuts)); if (NumOuts > 1) { for (int i = 1; i < NumOuts; ++i) { PADDLE_ENFORCE_EQ( (*outs)[i]->dims(), (*outs)[0]->dims(), paddle::platform::errors::InvalidArgument( "The shape of each output tensor shall be identical yet, " "but %dth output tensor`s shape is not.", i)); } } // calculate the max vec_size for all ins and outs int vec_size = GetVectorizedSizeForTensors(ins, *outs); switch (vec_size) { case 4: ElementwiseCudaKernel( ctx, ins, outs, func); break; case 2: ElementwiseCudaKernel( ctx, ins, outs, func); break; case 1: ElementwiseCudaKernel( ctx, ins, outs, func); break; default: { PADDLE_THROW(paddle::platform::errors::Unimplemented( "Unsupported vectorized size: %d !", vec_size)); break; } } } struct DimensionsTransform { using DimVector = std::vector; typedef void (*MergeFunctor)( bool &, std::vector &, DimVector &, int, int); int64_t dim_size; DimVector out_dims; std::vector in_dims; private: // To compensate the lackage of input_tensors` dimension with input variable // 'axis' void InputDimensionsExtend(int N, int axis) { for (auto &in_dim : in_dims) { int64_t in_idx = 0; if (in_dim.size() < dim_size) { DimVector tmp_dim(dim_size, 1); do { if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) { tmp_dim[axis] = in_dim[in_idx]; in_idx++; axis++; } else { PADDLE_THROW(paddle::platform::errors::InvalidArgument( "The %d-th dimension of input tensor is expected to be equal " "with the %d-th dimension of output tensor %d or 1, but " "recieved %d.", in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx])); } } while (in_idx < in_dim.size()); in_dim.resize(dim_size); std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin()); } else { do { if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) { in_idx++; } else { PADDLE_THROW(paddle::platform::errors::InvalidArgument( "The %d-th dimension of input tensor is expected to be equal " "with the %d-th dimension of output tensor %d or 1, but " "recieved %d.", in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx])); } } while (in_idx < dim_size); } std::reverse(in_dim.begin(), in_dim.end()); } std::reverse(out_dims.begin(), out_dims.end()); } template __inline__ void MergeDimensions(MergeFunctor merge_func, int N) { auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) { (*vec)[m_idx - 1] = std::accumulate(vec->begin() + l_idx, vec->begin() + m_idx, 1, std::multiplies()); vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1); }; int64_t i = 0; while (i < dim_size) { int cnt = 0; int low_idx = i; bool equal = true; do { merge_func(equal, in_dims, out_dims, i, N); if (equal) { i++; cnt++; } else { break; } } while (i < dim_size); if (cnt > 1) { for (auto &in_dim : in_dims) { VectorReorganise(&in_dim, low_idx, i); } VectorReorganise(&out_dims, low_idx, i); dim_size -= --cnt; i -= cnt; } else if (cnt < 1) { i++; } } } public: explicit DimensionsTransform(const std::vector &ins, const paddle::framework::DDim &dims, int axis) { const int N = ins.size(); dim_size = dims.size(); out_dims = paddle::framework::vectorize(dims); in_dims.resize(N); for (int j = 0; j < N; ++j) { in_dims[j] = paddle::framework::vectorize(ins[j]->dims()); } InputDimensionsExtend(N, axis); auto merge_sequential_dims = [](bool &equal, std::vector &in_dims, DimVector &out, int i, int num) { for (int j = 1; j < num; ++j) { equal &= (in_dims[0][i] == in_dims[j][i]) ? true : false; } }; auto merge_sequential_one_dims = [](bool &equal, std::vector &in_dims, DimVector &out, int i, int num) { equal = in_dims[0][i] == 1; if (equal) { for (int j = 1; j < num; ++j) { equal &= in_dims[j][i] == out[i]; } } }; // To Merge the dimensions of input_tensors while the consequtive // equal-dimensions appears. MergeFunctor merge_ptr = merge_sequential_dims; MergeDimensions(merge_ptr, N); int min_idx = 0; int min_val = std::accumulate( in_dims[0].begin(), in_dims[0].end(), 1, std::multiplies()); for (int j = 1; j < N; ++j) { int temp = std::accumulate( in_dims[j].begin(), in_dims[j].end(), 1, std::multiplies()); min_val = min_val > temp ? temp : min_val; min_idx = min_val == temp ? j : min_idx; } std::swap(in_dims[0], in_dims[min_idx]); // To Merge the dimension of input_tensors while the consequtive // 1-value-dimensions appears. merge_ptr = merge_sequential_one_dims; MergeDimensions(merge_ptr, N); std::swap(in_dims[min_idx], in_dims[0]); } }; template __device__ __forceinline__ void LoadData( T *dst, const _ptr_ T *src, uint32_t block_offset, const kps::details::BroadcastConfig &config, int numel, int num, int need_broadcast) { // numel : whole num of output // num: how many data will be deal with in this time if (need_broadcast) { kps::ReadDataBc( dst, src, block_offset, config, numel); } else { kps::ReadData(dst, src + block_offset, num); } } template __device__ void ElementwiseBroadcastKernelImpl( const paddle::framework::Array &ins, paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, const paddle::framework::Array &use_broadcast, uint32_t numel, const paddle::framework::Array, Arity> &configs, int num, int block_offset, Functor func) { InT args[Arity][VecSize]; ConditionalT result[VecSize]; #pragma unroll for (int i = 0; i < Arity; i++) { kps::Init(args[i], static_cast(1.0f)); LoadData(args[i], ins[i], block_offset, configs[i], numel, num, use_broadcast[i]); } constexpr bool kCallElementwiseAny = paddle::platform::FunctionTraits::has_pointer_args; ElementwisePrimitiveCaller, VecSize, Functor, Arity, kCallElementwiseAny>()(func, args, result); ElementwiseWriteDataCaller()( outs, result, block_offset, num); } template __global__ void ElementwiseBroadcastKernel( paddle::framework::Array ins, paddle::framework::Array<_ptr_ OutT *, NumOuts> outs, paddle::framework::Array use_broadcast, uint32_t numel, paddle::framework::Array, Arity> configs, int main_offset, int tail_tid, Functor func) { int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize; #ifdef PADDLE_WITH_XPU2 for (; block_offset < main_offset; block_offset += stride) { ElementwiseBroadcastKernelImpl(ins, outs, use_broadcast, numel, configs, BLOCK_NUM_X * VecSize, block_offset, func); } int num = numel - block_offset; if (num > 0) { ElementwiseBroadcastKernelImpl( ins, outs, use_broadcast, numel, configs, num, block_offset, func); } #else if (block_offset < main_offset) { ElementwiseBroadcastKernelImpl(ins, outs, use_broadcast, numel, configs, BLOCK_NUM_X * VecSize, block_offset, func); } else { ElementwiseBroadcastKernelImpl( ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func); } #endif } template void LaunchKernel(const KPDevice &ctx, const std::vector &ins, std::vector *outs, Functor func, DimensionsTransform merge_dims) { int numel = (*outs)[0]->numel(); paddle::framework::Array, Arity> configs; paddle::framework::Array use_broadcast; paddle::framework::Array ins_data; paddle::framework::Array<_ptr_ OutT *, NumOuts> outs_data; for (int i = 0; i < NumOuts; ++i) { outs_data[i] = (*outs)[i]->mutable_data(); } for (int i = 0; i < Arity; i++) { use_broadcast[i] = (ins[i]->numel() != numel); ins_data[i] = (_ptr_ InT *)(ins[i]->data()); if (use_broadcast[i]) { // get the broadcast config, // if data shape is[m, n], then you should set data_dim = {n, m} // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3} configs[i] = kps::details::BroadcastConfig( merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size); } } #ifdef PADDLE_WITH_XPU2 const int threads = 64; const int blocks = 8; int main_offset = (numel / (VecSize * threads)) * VecSize * threads; int tail_tid = numel % (VecSize * threads); auto stream = ctx.x_context()->xpu_stream; ElementwiseBroadcastKernel<<>>(ins_data, outs_data, use_broadcast, numel, configs, main_offset, tail_tid, func); #else const int threads = 256; int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads; int main_offset = (numel / (VecSize * threads)) * VecSize * threads; int tail_tid = numel % (VecSize * threads); auto stream = ctx.stream(); ElementwiseBroadcastKernel<<>>( ins_data, outs_data, use_broadcast, numel, configs, main_offset, tail_tid, func); #endif } template void LaunchBroadcastKernelForDifferentVecSize( const KPDevice &ctx, const std::vector &ins, std::vector *outs, int axis, Functor func) { const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis); #define CALL_BROADCAST_FOR_DIM_SIZE(rank) \ case rank: { \ LaunchKernel( \ ctx, ins, outs, func, merge_dims); \ } break; switch (merge_dims.dim_size) { CALL_BROADCAST_FOR_DIM_SIZE(1); CALL_BROADCAST_FOR_DIM_SIZE(2); CALL_BROADCAST_FOR_DIM_SIZE(3); CALL_BROADCAST_FOR_DIM_SIZE(4); CALL_BROADCAST_FOR_DIM_SIZE(5); CALL_BROADCAST_FOR_DIM_SIZE(6); CALL_BROADCAST_FOR_DIM_SIZE(7); CALL_BROADCAST_FOR_DIM_SIZE(8); default: { PADDLE_THROW(paddle::platform::errors::InvalidArgument( "The maximum dimension of input tensor is expected to be less than " "%d, but recieved %d.\n", merge_dims.dim_size, paddle::framework::DDim::kMaxRank)); } } #undef CALL_BROADCAST_FOR_DIM_SIZE } template void LaunchBroadcastElementwiseCudaKernel( const KPDevice &ctx, const std::vector &ins, std::vector *outs, int axis, Functor func) { using Traits = paddle::platform::FunctionTraits; const int kArity = Traits::has_pointer_args ? static_cast(ET) : Traits::arity; PADDLE_ENFORCE_EQ(ins.size(), kArity, paddle::platform::errors::InvalidArgument( "The number of inputs is expected to be equal to the " "arity of functor. But recieved: the number of inputs " "is %d, the arity of functor is %d.", ins.size(), kArity)); PADDLE_ENFORCE_LE(kArity, 3, paddle::platform::errors::InvalidArgument( "Currently only broadcast of ternary is supported " "and verified, but received %d.", kArity)); PADDLE_ENFORCE_EQ(outs->size(), NumOuts, paddle::platform::errors::InvalidArgument( "Number of outputs shall equal to number of functions, " "but number of outputs is %d, of functions is %d.", outs->size(), NumOuts)); int in_vec_size = 4; int out_vec_size = 4; if (NumOuts > 1) { for (int i = 0; i < NumOuts; ++i) { PADDLE_ENFORCE_EQ( (*outs)[i]->dims(), (*outs)[0]->dims(), paddle::platform::errors::InvalidArgument( "The shape of each output tensor shall be identical yet, but " "%dth output tensor`s shape is not.", i)); out_vec_size = std::min( paddle::platform::GetVectorizedSize((*outs)[i]->data()), out_vec_size); } } else { out_vec_size = paddle::platform::GetVectorizedSize((*outs)[0]->data()); } for (auto *in : ins) { auto temp_size = paddle::platform::GetVectorizedSize(in->data()); in_vec_size = in->dims() == (*outs)[0]->dims() ? std::min(temp_size, in_vec_size) : in_vec_size; } int vec_size = std::min(out_vec_size, in_vec_size); switch (vec_size) { case 4: { LaunchBroadcastKernelForDifferentVecSize(ctx, ins, outs, axis, func); break; } case 2: { LaunchBroadcastKernelForDifferentVecSize(ctx, ins, outs, axis, func); break; } case 1: { LaunchBroadcastKernelForDifferentVecSize(ctx, ins, outs, axis, func); break; } default: { PADDLE_THROW(paddle::platform::errors::Unimplemented( "Unsupported vectorized size: %d !", vec_size)); break; } } } template void LaunchElementwiseCudaKernel(const KPDevice &ctx, const std::vector &ins, std::vector *outs, int axis, Functor func) { std::vector dims_size; bool no_broadcast_flag = true; for (auto *in : ins) { no_broadcast_flag &= ins[0]->dims() == in->dims(); dims_size.emplace_back(in->dims().size()); } if (no_broadcast_flag) { LaunchSameDimsElementwiseCudaKernel( ctx, ins, outs, func); } else { axis = axis == -1 ? *std::max_element(dims_size.begin(), dims_size.end()) - *std::min_element(dims_size.begin(), dims_size.end()) : axis; LaunchBroadcastElementwiseCudaKernel( ctx, ins, outs, axis, func); } } } // namespace pten