// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <numeric>
#include <vector>

#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"

namespace paddle {
namespace operators {

namespace kps = paddle::operators::kernel_primitives;

struct DimensionsTransform {
  using DimVector = std::vector<int64_t>;
  typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
                               int, int);
  int64_t dim_size;
  DimVector out_dims;
  std::vector<DimVector> in_dims;

 private:
  // Pad the shorter input shapes up to the output rank, aligning them at the
  // position given by 'axis' and filling the remaining dimensions with 1.
  void InputDimensionsExtend(int N, int axis) {
    for (auto &in_dim : in_dims) {
      int64_t in_idx = 0;
      if (in_dim.size() < dim_size) {
        DimVector tmp_dim(dim_size, 1);
        do {
          if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
            tmp_dim[axis] = in_dim[in_idx];
            in_idx++;
            axis++;
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %d-th dimension of input tensor is expected to be equal "
                "to the %d-th dimension of output tensor %d or to be 1, but "
                "received %d.",
                in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx]));
          }
        } while (in_idx < in_dim.size());
        in_dim.resize(dim_size);
        std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
      } else {
        do {
          if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
            in_idx++;
          } else {
            PADDLE_THROW(platform::errors::InvalidArgument(
                "The %d-th dimension of input tensor is expected to be equal "
                "to the %d-th dimension of output tensor %d or to be 1, but "
                "received %d.",
                in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx]));
          }
        } while (in_idx < dim_size);
      }
      std::reverse(in_dim.begin(), in_dim.end());
    }
    std::reverse(out_dims.begin(), out_dims.end());
  }

  template <typename MergeFunctor>
  __inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
    // Collapse the dimensions in [l_idx, m_idx) into a single dimension.
    auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
      (*vec)[m_idx - 1] = std::accumulate(vec->begin() + l_idx,
                                          vec->begin() + m_idx, 1,
                                          std::multiplies<int64_t>());
      vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
    };

    int64_t i = 0;
    while (i < dim_size) {
      int cnt = 0;
      int low_idx = i;
      bool equal = true;
      do {
        merge_func(equal, in_dims, out_dims, i, N);
        if (equal) {
          i++;
          cnt++;
        } else {
          break;
        }
      } while (i < dim_size);

      if (cnt > 1) {
        for (auto &in_dim : in_dims) {
          VectorReorganise(&in_dim, low_idx, i);
        }
        VectorReorganise(&out_dims, low_idx, i);
        dim_size -= --cnt;
        i -= cnt;
      } else if (cnt < 1) {
        i++;
      }
    }
  }

 public:
  explicit DimensionsTransform(
      const std::vector<const framework::Tensor *> &ins,
      const framework::DDim &dims, int axis) {
    const int N = ins.size();
    dim_size = dims.size();
    out_dims = framework::vectorize<int64_t>(dims);
    in_dims.resize(N);
    for (int j = 0; j < N; ++j) {
      in_dims[j] = framework::vectorize<int64_t>(ins[j]->dims());
    }
    InputDimensionsExtend(N, axis);

    auto merge_sequential_dims = [](bool &equal,
                                    std::vector<DimVector> &in_dims,
                                    DimVector &out, int i, int num) {
      for (int j = 1; j < num; ++j) {
        equal = (in_dims[0][i] == in_dims[j][i]) ? true : false;
      }
    };
    auto merge_sequential_one_dims = [](bool &equal,
                                        std::vector<DimVector> &in_dims,
                                        DimVector &out, int i, int num) {
      equal = in_dims[0][i] == 1;
      if (equal) {
        for (int j = 1; j < num; ++j) {
          equal = in_dims[j][i] == out[i];
        }
      }
    };
    // Merge consecutive dimensions of the input tensors that are equal
    // across all inputs.
    MergeFunctor merge_ptr = merge_sequential_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);

    int min_idx = 0;
    int min_val = std::accumulate(in_dims[0].begin(), in_dims[0].end(), 1,
                                  std::multiplies<int64_t>());
    for (int j = 1; j < N; ++j) {
      int temp = std::accumulate(in_dims[j].begin(), in_dims[j].end(), 1,
                                 std::multiplies<int64_t>());
      min_val = min_val > temp ? temp : min_val;
      min_idx = min_val == temp ? j : min_idx;
    }
    std::swap(in_dims[0], in_dims[min_idx]);

    // Merge consecutive dimensions in which the smallest input has size 1
    // while the other inputs match the output.
    merge_ptr = merge_sequential_one_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);
    std::swap(in_dims[min_idx], in_dims[0]);
  }
};
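// A worked example (for illustration only; shapes chosen arbitrarily):
// with out shape [2, 3, 4], in0 shape [2, 3, 4], in1 shape [3, 4] and
// axis = 1, InputDimensionsExtend pads in1 to [1, 3, 4] and reverses all
// shapes into innermost-first order, giving out = {4, 3, 2}, in0 = {4, 3, 2},
// in1 = {4, 3, 1}. The first merge pass collapses the leading dimensions that
// are equal across all inputs, so out = {12, 2}, in0 = {12, 2}, in1 = {12, 1};
// the second pass finds no further consecutive broadcast dimensions to fold,
// and dim_size ends up as 2.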
template <typename T, int VecSize, int Rank, bool IsBoundary = false>
__device__ __forceinline__ void LoadData(
    T *dst, const T *__restrict__ src, uint32_t block_offset,
    const kps::details::BroadcastConfig<Rank> &config, int numel, int num,
    bool need_broadcast) {
  // numel: total number of output elements
  // num: number of elements handled by this block in this pass
  if (need_broadcast) {
    kps::ReadDataBc<T, VecSize, 1, 1, Rank, IsBoundary>(dst, src, block_offset,
                                                        config, numel);
  } else {
    kps::ReadData<T, VecSize, 1, 1, IsBoundary>(dst, src + block_offset, num);
  }
}

template <typename InT, typename OutT, typename Functor, int Arity,
          int VecSize, int Rank, bool IsBoundary = false>
__device__ void DealSegment(
    const framework::Array<const InT *__restrict__, Arity> &ins, OutT *out,
    const framework::Array<bool, Arity> &use_broadcast, uint32_t numel,
    const framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
        &configs,
    int num, Functor func) {
  InT args[Arity][VecSize];
  OutT result[VecSize];

  int block_offset = blockIdx.x * blockDim.x * VecSize;

#pragma unroll
  for (int i = 0; i < Arity; i++) {
    kps::Init<InT, VecSize>(args[i], static_cast<InT>(1.0f));
    LoadData<InT, VecSize, Rank, IsBoundary>(args[i], ins[i], block_offset,
                                             configs[i], numel, num,
                                             use_broadcast[i]);
  }

  const bool kCallElementwiseAny =
      platform::FunctionTraits<Functor>::has_pointer_args;
  ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity,
                             kCallElementwiseAny>()(func, args, result);
  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(out + block_offset, result,
                                                  num);
}

template <typename InT, typename OutT, typename Functor, int Arity,
          int VecSize, int Rank>
__global__ void BroadcastKernel(
    framework::Array<const InT *__restrict__, Arity> ins, OutT *out,
    framework::Array<bool, Arity> use_broadcast, uint32_t numel,
    framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs,
    int main_tid, int tail_tid, Functor func) {
  // data offset of this block
  int block_offset = blockIdx.x * blockDim.x * VecSize;
  if (blockIdx.x < main_tid) {
    // full blocks: every thread handles VecSize elements, no boundary check
    int num = blockDim.x * VecSize;
    DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, false>(
        ins, out, use_broadcast, numel, configs, num, func);
  } else {
    // remainder: the last block handles the tail with boundary checks
    int num = tail_tid;
    DealSegment<InT, OutT, Functor, Arity, VecSize, Rank, true>(
        ins, out, use_broadcast, numel, configs, num, func);
  }
}

template <typename InT, typename OutT, typename Functor, int Arity,
          int VecSize, int Rank>
void LaunchKernel(const platform::CUDADeviceContext &ctx,
                  const std::vector<const framework::Tensor *> &ins,
                  framework::Tensor *out, Functor func,
                  DimensionsTransform merge_dims) {
  int numel = out->numel();
  const int threads = 256;
  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;

  int main_tid = numel / (VecSize * threads);
  int tail_tid = numel % (VecSize * threads);
  auto stream = ctx.stream();
  OutT *out_data = out->data<OutT>();

  framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
  framework::Array<bool, Arity> use_broadcast;
  framework::Array<const InT *__restrict__, Arity> ins_data;

  for (int i = 0; i < Arity; i++) {
    use_broadcast[i] = (ins[i]->numel() != numel);
    ins_data[i] = ins[i]->data<InT>();
    if (use_broadcast[i]) {
      // Build the broadcast config. Dimensions are passed innermost-first:
      // if the data shape is [m, n], data_dim should be {n, m}. E.g. an
      // output shape of [3, 45, 1] gives out_dims = {1, 45, 3}.
      configs[i] = kps::details::BroadcastConfig<Rank>(
          merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
    }
  }

  BroadcastKernel<InT, OutT, Functor, Arity, VecSize,
                  Rank><<<blocks, threads, 0, stream>>>(
      ins_data, out_data, use_broadcast, numel, configs, main_tid, tail_tid,
      func);
}
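// Grid-size arithmetic above, illustrated with assumed values (not taken from
// the original source): for numel = 3000, VecSize = 4 and threads = 256, each
// full block covers 256 * 4 = 1024 elements, so main_tid = 3000 / 1024 = 2
// blocks take the fast non-boundary path, tail_tid = 3000 % 1024 = 952
// elements are left for the boundary path, and
// blocks = ((3000 + 3) / 4 + 255) / 256 = 3 blocks are launched in total.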
template <typename InT, typename OutT, typename Functor, int Arity,
          int VecSize>
void LaunchBroadcastKernelForDifferentVecSize(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
    int axis, Functor func) {
  const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);

#define CALL_BROADCAST_FOR_DIM_SIZE(rank)                                     \
  case rank: {                                                                \
    LaunchKernel<InT, OutT, Functor, Arity, VecSize, rank>(ctx, ins, out,     \
                                                           func, merge_dims); \
  } break;

  switch (merge_dims.dim_size) {
    CALL_BROADCAST_FOR_DIM_SIZE(1);
    CALL_BROADCAST_FOR_DIM_SIZE(2);
    CALL_BROADCAST_FOR_DIM_SIZE(3);
    CALL_BROADCAST_FOR_DIM_SIZE(4);
    CALL_BROADCAST_FOR_DIM_SIZE(5);
    CALL_BROADCAST_FOR_DIM_SIZE(6);
    CALL_BROADCAST_FOR_DIM_SIZE(7);
    CALL_BROADCAST_FOR_DIM_SIZE(8);
    default: {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The maximum dimension of input tensor is expected to be less than "
          "%d, but received %d.\n",
          framework::DDim::kMaxRank, merge_dims.dim_size));
    }
  }
#undef CALL_BROADCAST_FOR_DIM_SIZE
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchBroadcastElementwiseCudaKernel(
    const platform::CUDADeviceContext &ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  using Traits = platform::FunctionTraits<Functor>;
  const int kArity =
      Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
  PADDLE_ENFORCE_EQ(ins.size(), kArity,
                    platform::errors::InvalidArgument(
                        "The number of inputs is expected to be equal to the "
                        "arity of functor. But received: the number of inputs "
                        "is %d, the arity of functor is %d.",
                        ins.size(), kArity));
  PADDLE_ENFORCE_EQ(kArity, 2,
                    platform::errors::InvalidArgument(
                        "Currently only broadcast of binary functors is "
                        "supported and verified, but received %d.",
                        kArity));

  int in_vec_size = 4;
  framework::Tensor *out = (*outs)[0];
  for (auto *in : ins) {
    auto temp_size = platform::GetVectorizedSize<InT>(in->data<InT>());
    in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
                                            : in_vec_size;
  }
  int out_vec_size = platform::GetVectorizedSize<OutT>(out->data<OutT>());
  int vec_size = std::min(out_vec_size, in_vec_size);

  switch (vec_size) {
    case 4: {
      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 4>(
          ctx, ins, out, axis, func);
      break;
    }
    case 2: {
      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 2>(
          ctx, ins, out, axis, func);
      break;
    }
    case 1: {
      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 1>(
          ctx, ins, out, axis, func);
      break;
    }
    default: {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported vectorized size: %d!", vec_size));
      break;
    }
  }
}

template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
void LaunchElementwiseCudaKernel(
    const platform::CUDADeviceContext &cuda_ctx,
    const std::vector<const framework::Tensor *> &ins,
    std::vector<framework::Tensor *> *outs, int axis, Functor func) {
  std::vector<int> dims_size;
  bool no_broadcast_flag = true;
  for (auto *in : ins) {
    no_broadcast_flag = ins[0]->dims() == in->dims();
    dims_size.emplace_back(in->dims().size());
  }

  if (no_broadcast_flag) {
    LaunchSameDimsElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
                                                       func);
  } else {
    axis = axis == -1
               ? *std::max_element(dims_size.begin(), dims_size.end()) -
                     *std::min_element(dims_size.begin(), dims_size.end())
               : axis;
    LaunchBroadcastElementwiseCudaKernel<ET, InT, OutT>(cuda_ctx, ins, outs,
                                                        axis, func);
  }
}

}  // namespace operators
}  // namespace paddle
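// Minimal usage sketch (illustrative only; "AddFunctor" is a hypothetical
// binary functor, not defined in this file):
//
//   template <typename T>
//   struct AddFunctor {
//     inline HOSTDEVICE T operator()(const T &a, const T &b) const {
//       return a + b;
//     }
//   };
//
//   std::vector<const framework::Tensor *> ins = {&x, &y};
//   std::vector<framework::Tensor *> outs = {&z};
//   LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
//       dev_ctx, ins, &outs, axis, AddFunctor<T>());
//
// The dispatcher takes the same-dims fast path when all input shapes match
// the output, and otherwise falls back to the broadcast path with the given
// (or inferred) axis.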