diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc
index b551629169deed66a1a79636287569995726c4be..63f62347b81c6aee1ae9aea6397081885f8781da 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc
@@ -69,6 +69,15 @@ struct SameDimsElemwiseAdd<
   }
 };
 
+template <typename T>
+struct BroadcastElemwiseAdd<platform::CPUDeviceContext, T> {
+  void operator()(const framework::ExecutionContext &ctx,
+                  const framework::Tensor *x, const framework::Tensor *y,
+                  framework::Tensor *z) {
+    default_elementwise_add<platform::CPUDeviceContext, T>(ctx, x, y, z);
+  }
+};
+
 class ElementwiseAddOpMaker : public ElementwiseOpMaker {
  protected:
   std::string GetName() const override { return "Add"; }
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
index dc9c18ba038861b763cb52863ddae8ac69db5022..7b42803aa51ec6e34ee1484ee46486091a560a9f 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
@@ -51,6 +52,20 @@ struct SameDimsElemwiseAdd<platform::CUDADeviceContext, T> {
   }
 };
 
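+// GPU broadcast path: the two inputs are packed into a vector and handed to
+// the broadcast launcher declared in elementwise_op_broadcast.cu.h. An axis
+// attribute of -1 means "align trailing dimensions" and is normalized here to
+// the rank difference of the two inputs before launching.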
+template <typename T>
+struct BroadcastElemwiseAdd<platform::CUDADeviceContext, T> {
+  void operator()(const framework::ExecutionContext& ctx,
+                  const framework::Tensor* x, const framework::Tensor* y,
+                  framework::Tensor* out) {
+    std::vector<const framework::Tensor*> ins = {x, y};
+    int axis = ctx.Attr<int>("axis");
+    axis = axis == -1 ? std::abs(x->dims().size() - y->dims().size()) : axis;
+    LaunchBroadcastElementwiseCudaKernel<ElementwiseType::kBinary, T>(
+        ctx.template device_context<platform::CUDADeviceContext>(), ins, out,
+        CudaAddFunctor<T>(), axis);
+  }
+};
+
 template <typename T>
 static __global__ void SimpleElemwiseAddGradCUDAKernel(
     const T* __restrict__ dout, int size, int vec_size, T* dx, T* dy) {
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h
index abea9da9423553e177581a30c02fe73dc50369c6..57f6629702214f15caeaf281a68e5e6eb2e61567 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h
@@ -20,11 +20,13 @@ limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/math_function.h"
+
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #ifdef __NVCC__
 #include <cuda.h>
 #include <cuda_fp16.h>
 #include "cub/cub.cuh"
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #endif
 #ifdef __HIPCC__
 #include <hip/hip_fp16.h>
@@ -60,6 +62,13 @@ struct SameDimsElemwiseAdd {
                   framework::Tensor *z);
 };
 
+template <typename DeviceContext, typename T>
+struct BroadcastElemwiseAdd {
+  void operator()(const framework::ExecutionContext &ctx,
+                  const framework::Tensor *x, const framework::Tensor *y,
+                  framework::Tensor *z);
+};
+
 template <typename DeviceContext, typename T>
 class ElementwiseAddKernel : public framework::OpKernel<T> {
  public:
@@ -73,7 +82,8 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
       SameDimsElemwiseAdd<DeviceContext, T> same_dims_add;
       same_dims_add(ctx, x, y, z);
     } else {
-      default_elementwise_add<DeviceContext, T>(ctx, x, y, z);
+      BroadcastElemwiseAdd<DeviceContext, T> broadcast_add;
+      broadcast_add(ctx, x, y, z);
     }
   }
 };
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
new file mode 100644
index 0000000000000000000000000000000000000000..c9657a1b9db04675b2ebe17ecfc347df60eb7e29
--- /dev/null
+++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
@@ -0,0 +1,468 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast_impl.cu.h"
+
+namespace paddle {
+namespace operators {
+
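+// DimensionsTransform normalizes the shapes involved in one broadcast: every
+// input is padded with 1s up to the output rank, dimensions are reversed so
+// that the fastest-varying axis comes first, and runs of dimensions that are
+// either all equal or all broadcast together are folded into one dimension.
+// For example, with x.dims = (8, 4, 2, 2), y.dims = (1, 1, 2, 2) and
+// out.dims = (8, 4, 2, 2), the transform yields out_dims = (4, 32) and
+// in_dims = {(4, 32), (4, 1)} (fastest axis first), so the kernel only has to
+// handle a 2-D broadcast.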
+struct DimensionsTransform {
+  using DimVector = std::vector<int64_t>;
+  typedef void (*MergeFunctor)(bool &, std::vector<DimVector> &, DimVector &,
+                               int, int);
+  int64_t dim_size;
+  DimVector out_dims;
+  std::vector<DimVector> in_dims;
+
+ private:
+  // To compensate for any missing dimensions, pad each input's shape with 1s
+  // until it reaches the rank of the output.
+  void InputDimensionsExtend(int N, int axis) {
+    for (auto &in_dim : in_dims) {
+      int64_t in_idx = 0;
+      if (in_dim.size() < dim_size) {
+        DimVector tmp_dim(dim_size, 1);
+        do {
+          if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
+            tmp_dim[axis] = in_dim[in_idx];
+            in_idx++;
+            axis++;
+          } else {
+            PADDLE_THROW(platform::errors::InvalidArgument(
+                "The %dth dimension of input tensor is expected to be equal "
+                "to the %dth dimension of output tensor %d or to be 1, but "
+                "received %d.\n",
+                in_idx + 1, axis + 1, out_dims[axis], in_dim[in_idx]));
+          }
+        } while (in_idx < in_dim.size());
+        in_dim.resize(dim_size);
+        std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
+      } else {
+        do {
+          if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
+            in_idx++;
+          } else {
+            PADDLE_THROW(platform::errors::InvalidArgument(
+                "The %dth dimension of input tensor is expected to be equal "
+                "to the %dth dimension of output tensor %d or to be 1, but "
+                "received %d.\n",
+                in_idx + 1, in_idx + 1, out_dims[in_idx], in_dim[in_idx]));
+          }
+        } while (in_idx < dim_size);
+      }
+      std::reverse(in_dim.begin(), in_dim.end());
+    }
+    std::reverse(out_dims.begin(), out_dims.end());
+  }
+
+  template <typename MergeFunctor>
+  __inline__ void DimensionsReorganise(MergeFunctor merge_func, int N) {
+    auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
+      (*vec)[m_idx - 1] =
+          std::accumulate(vec->begin() + l_idx, vec->begin() + m_idx, 1,
+                          std::multiplies<int64_t>());
+      vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
+    };
+
+    int64_t i = 0;
+    while (i < dim_size) {
+      int cnt = 0;
+      int low_idx = i;
+      bool equal = true;
+      do {
+        merge_func(equal, in_dims, out_dims, i, N);
+        if (equal) {
+          i++;
+          cnt++;
+        } else {
+          break;
+        }
+      } while (i < dim_size);
+
+      if (cnt > 1) {
+        for (auto &in_dim : in_dims) {
+          VectorReorganise(&in_dim, low_idx, i);
+        }
+        VectorReorganise(&out_dims, low_idx, i);
+        dim_size -= --cnt;
+        i -= cnt;
+      } else if (cnt < 1) {
+        i++;
+      }
+    }
+  }
+
+ public:
+  explicit DimensionsTransform(
+      const std::vector<const framework::Tensor *> &ins,
+      const framework::DDim &dims, int axis) {
+    const int N = ins.size();
+    dim_size = dims.size();
+    out_dims = framework::vectorize(dims);
+    in_dims.resize(N);
+    for (int j = 0; j < N; ++j) {
+      in_dims[j] = framework::vectorize(ins[j]->dims());
+    }
+    InputDimensionsExtend(N, axis);
+
+    auto merge_sequential_dims = [](bool &equal,
+                                    std::vector<DimVector> &in_dims,
+                                    DimVector &out, int i, int num) {
+      for (int j = 1; j < num; ++j) {
+        equal = (in_dims[0][i] == in_dims[j][i]) ? true : false;
+      }
+    };
+    auto merge_sequential_one_dims = [](bool &equal,
+                                        std::vector<DimVector> &in_dims,
+                                        DimVector &out, int i, int num) {
+      equal = in_dims[0][i] == 1;
+      if (equal) {
+        for (int j = 1; j < num; ++j) {
+          equal = in_dims[j][i] == out[i];
+        }
+      }
+    };
+    // To merge the dimensions of the input tensors where consecutive equal
+    // dimensions appear.
+    MergeFunctor merge_ptr = merge_sequential_dims;
+    DimensionsReorganise(merge_ptr, N);
+
+    int min_idx = 0;
+    int min_val = std::accumulate(in_dims[0].begin(), in_dims[0].end(), 1,
+                                  std::multiplies<int64_t>());
+    for (int j = 1; j < N; ++j) {
+      int temp = std::accumulate(in_dims[j].begin(), in_dims[j].end(), 1,
+                                 std::multiplies<int64_t>());
+      min_val = min_val > temp ? temp : min_val;
+      min_idx = min_val == temp ? j : min_idx;
+    }
+    std::swap(in_dims[0], in_dims[min_idx]);
+
+    // To merge the dimensions of the input tensors where consecutive
+    // 1-valued dimensions appear.
+    merge_ptr = merge_sequential_one_dims;
+    DimensionsReorganise(merge_ptr, N);
+    std::swap(in_dims[min_idx], in_dims[0]);
+  }
+};
+
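+// CalculateInputStrides precomputes, for every already-merged input shape,
+// the stride of each dimension in that input's memory layout, with broadcast
+// dimensions (size 1) mapped to a stride of 0 so they contribute nothing to
+// the load offset. Continuing the example above, with out_dims = (4, 32) the
+// strides are {1, 4} for the (4, 32) input and {1, 0} for the (4, 1) input,
+// and one FastDivMod divider is built per merged output dimension.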
+struct CalculateInputStrides {
+  std::vector<std::vector<uint32_t>> strides;
+  std::vector<FastDivMod> divmoders;
+
+ private:
+  // To calculate the strides of each input tensor.
+  __inline__ void CalculateStrides(
+      int N, int dim_size, const std::vector<std::vector<int64_t>> &in_dims) {
+    for (int j = 0; j < N; ++j) {
+      for (int i = 0; i < dim_size; ++i) {
+        strides[j][i] = in_dims[j][i] == 1 ? 0 : strides[j][i];
+        strides[j][i] =
+            (i != 0 && strides[j][i] != 0)
+                ? std::accumulate(in_dims[j].begin(), in_dims[j].begin() + i,
+                                  1, std::multiplies<int64_t>())
+                : strides[j][i];
+      }
+    }
+  }
+
+ public:
+  explicit CalculateInputStrides(
+      const int64_t &dim_size,
+      const std::vector<std::vector<int64_t>> &in_dims,
+      const std::vector<int64_t> &out_dims) {
+    const auto N = in_dims.size();
+    divmoders.resize(dim_size);
+    strides.resize(N, std::vector<uint32_t>(dim_size, 1));
+
+    for (int i = 0; i < dim_size; ++i) {
+      divmoders[i] = FastDivMod(out_dims[i]);
+    }
+    CalculateStrides(N, dim_size, in_dims);
+  }
+};
+
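+// BroadcastArgsWarpper carries everything a thread needs into the kernel:
+// raw input/output pointers, the per-input strides, the FastDivMod dividers,
+// and a flag telling whether an input already has the full output shape.
+// Full-shape inputs are read with vectorized loads; broadcast inputs are
+// gathered element by element through GetDivmodOffset, which decomposes an
+// output index dimension by dimension. E.g. for output index 37 with
+// out_dims = (4, 32): 37 = 9 * 4 + 1, so the (4, 1) input with strides
+// {1, 0} reads offset 1 * 1 + 9 * 0 = 1, while the (4, 32) input with
+// strides {1, 4} would read offset 1 * 1 + 9 * 4 = 37 (it is not broadcast,
+// so it actually takes the vectorized path and skips the divmod entirely).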
+template <typename T, ElementwiseType ET, int VecSize, int kDims>
+struct BroadcastArgsWarpper {
+  using DimsVec = CudaAlignedVector<T, VecSize>;
+
+  T *out_data;
+  const T *__restrict__ in_data[ET];
+  uint32_t strides[ET][framework::DDim::kMaxRank];
+  bool no_broadcast[ET];
+  FastDivMod divmoders[kDims];
+  uint32_t scalar_offset;
+
+  HOSTDEVICE BroadcastArgsWarpper(
+      const std::vector<const framework::Tensor *> &ins,
+      const CalculateInputStrides &offset_calculator, framework::Tensor *out,
+      int scalar_offset)
+      : scalar_offset(scalar_offset) {
+    for (int j = 0; j < ET; ++j) {
+      in_data[j] = ins[j]->data<T>();
+      no_broadcast[j] = ins[j]->dims() == out->dims() ? true : false;
+      memcpy(strides[j], offset_calculator.strides[j].data(),
+             kDims * sizeof(uint32_t));
+    }
+    out_data = out->data<T>();
+    memcpy(divmoders, offset_calculator.divmoders.data(),
+           kDims * sizeof(FastDivMod));
+  }
+
+  __device__ __forceinline__ uint32_t GetDivmodOffset(int idx, int in_idx) {
+    uint32_t offset = 0;
+
+#pragma unroll(kDims)
+    for (int i = 0; i < kDims; ++i) {
+      auto fast_divmoder = divmoders[i].Divmod(idx);
+      idx = fast_divmoder.val[0];
+      offset += fast_divmoder.val[1] * strides[in_idx][i];
+    }
+    return offset;
+  }
+
+  __device__ __forceinline__ void CommonVector(DimsVec args[], int tid,
+                                               int idx) {
+    const DimsVec *__restrict__ vec_data =
+        reinterpret_cast<const DimsVec *__restrict__>(in_data[idx]);
+    args[idx] = vec_data[tid];
+  }
+
+  __device__ __forceinline__ void DivmodVector(DimsVec args[], int tid,
+                                               int idx) {
+    int index = tid * VecSize;
+
+    for (int i = 0; i < VecSize; ++i) {
+      uint32_t offset = GetDivmodOffset(index + i, idx);
+      args[idx].val[i] = in_data[idx][offset];
+    }
+  }
+
+  __device__ __forceinline__ void CommonScalar(T args[], int tid, int idx) {
+    args[idx] = in_data[idx][tid + scalar_offset];
+  }
+
+  __device__ __forceinline__ void DivmodScalar(T args[], int tid, int idx) {
+    auto offset = GetDivmodOffset(tid + scalar_offset, idx);
+    args[idx] = in_data[idx][offset];
+  }
+
+  __device__ __forceinline__ void LoadVector(DimsVec args[], int tid) {
+#pragma unroll(ET)
+    for (int j = 0; j < ET; ++j) {
+      if (no_broadcast[j]) {
+        CommonVector(args, tid, j);
+      } else {
+        DivmodVector(args, tid, j);
+      }
+    }
+  }
+
+  __device__ __forceinline__ void LoadScalar(T args[], int tid) {
+#pragma unroll(ET)
+    for (int j = 0; j < ET; ++j) {
+      if (no_broadcast[j]) {
+        CommonScalar(args, tid, j);
+      } else {
+        DivmodScalar(args, tid, j);
+      }
+    }
+  }
+
+  __device__ __forceinline__ void StoreVector(DimsVec args[], int tid) {
+    DimsVec *vec_out = reinterpret_cast<DimsVec *>(out_data);
+    vec_out[tid] = args[0];
+  }
+
+  __device__ __forceinline__ void StoreScalar(T args[], int tid) {
+    out_data[scalar_offset + tid] = args[0];
+  }
+};
+
+template <typename T, ElementwiseType ET, typename BroadcastArgsWarpper>
+__device__ inline void ScalarizedBroadcastKernelImpl(
+    BroadcastArgsWarpper data_transfer, int tid) {
+  T args[ET];
+  data_transfer.LoadScalar(args, tid);
+
+#pragma unroll(ET)
+  for (int j = 1; j < ET; ++j) {
+    args[0] += args[j];
+  }
+  data_transfer.StoreScalar(args, tid);
+}
+
+template <typename T, ElementwiseType ET, int VecSize,
+          typename BroadcastArgsWarpper>
+__device__ inline void VectorizedBroadcastKernelImpl(
+    BroadcastArgsWarpper data_transfer, int tid) {
+  using VecT = CudaAlignedVector<T, VecSize>;
+  VecT args[ET];
+  data_transfer.LoadVector(args, tid);
+
+#pragma unroll(ET)
+  for (int j = 1; j < ET; ++j) {
+#pragma unroll(VecSize)
+    for (int i = 0; i < VecSize; ++i) {
+      args[0].val[i] += args[j].val[i];
+    }
+  }
+  data_transfer.StoreVector(args, tid);
+}
+
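+// The grid is split into a vectorized part and a scalar tail: main_tid
+// threads each handle VecSize consecutive output elements, and tail_tid
+// threads handle the remaining numel % VecSize elements one by one, starting
+// at the vectorized length main_tid * VecSize. E.g. numel = 1030 with
+// VecSize = 4 gives 257 vectorized threads covering elements [0, 1028) and
+// 2 scalar threads covering elements 1028 and 1029.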
+template <typename T, ElementwiseType ET, int VecSize,
+          typename BroadcastArgsWarpper>
+__global__ void ElementwiseBroadcastKernel(BroadcastArgsWarpper data_transfer,
+                                           int main_tid, int tail_tid) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+  // Vectorized calculation of the main block of data, whose length is the
+  // largest multiple of VecSize that fits into the output.
+  if (tid < main_tid) {
+    VectorizedBroadcastKernelImpl<T, ET, VecSize, BroadcastArgsWarpper>(
+        data_transfer, tid);
+  }
+  // Scalar calculation of the remaining elements, which cannot fill a whole
+  // vector of VecSize elements.
+  if (tid < tail_tid) {
+    ScalarizedBroadcastKernelImpl<T, ET, BroadcastArgsWarpper>(data_transfer,
+                                                               tid);
+  }
+}
+
+template <typename T, ElementwiseType ET, int VecSize>
+void LaunchBroadcastKernelForDifferentDimSize(
+    const platform::CUDADeviceContext &ctx,
+    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
+    int axis) {
+  int numel = out->numel();
+  const int threads = 256;
+  int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
+  int main_tid = numel / VecSize;
+  int tail_tid = numel % VecSize;
+  int vec_len = main_tid * VecSize;
+  auto stream = ctx.stream();
+
+  const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
+  const auto offset_calculator = CalculateInputStrides(
+      merge_dims.dim_size, merge_dims.in_dims, merge_dims.out_dims);
+
+  switch (merge_dims.dim_size) {
+    case 1: {
+      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 1>(
+          ins, offset_calculator, out, vec_len);
+      ElementwiseBroadcastKernel<T, ET, VecSize, decltype(data_transfer)><<<
+          blocks, threads, 0, stream>>>(data_transfer, main_tid, tail_tid);
+      break;
+    }
+    case 2: {
+      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 2>(
+          ins, offset_calculator, out, vec_len);
+      ElementwiseBroadcastKernel<T, ET, VecSize, decltype(data_transfer)><<<
+          blocks, threads, 0, stream>>>(data_transfer, main_tid, tail_tid);
+      break;
+    }
+    case 3: {
+      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 3>(
+          ins, offset_calculator, out, vec_len);
+      ElementwiseBroadcastKernel<T, ET, VecSize, decltype(data_transfer)><<<
+          blocks, threads, 0, stream>>>(data_transfer, main_tid, tail_tid);
+      break;
+    }
+    case 4: {
+      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 4>(
+          ins, offset_calculator, out, vec_len);
+      ElementwiseBroadcastKernel<T, ET, VecSize, decltype(data_transfer)><<<
+          blocks, threads, 0, stream>>>(data_transfer, main_tid, tail_tid);
+      break;
+    }
+    case 5: {
+      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 5>(
+          ins, offset_calculator, out, vec_len);
+      ElementwiseBroadcastKernel<T, ET, VecSize, decltype(data_transfer)><<<
+          blocks, threads, 0, stream>>>(data_transfer, main_tid, tail_tid);
+      break;
+    }
+    case 6: {
+      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 6>(
+          ins, offset_calculator, out, vec_len);
+      ElementwiseBroadcastKernel<T, ET, VecSize, decltype(data_transfer)><<<
+          blocks, threads, 0, stream>>>(data_transfer, main_tid, tail_tid);
+      break;
+    }
+    case 7: {
+      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 7>(
+          ins, offset_calculator, out, vec_len);
+      ElementwiseBroadcastKernel<T, ET, VecSize, decltype(data_transfer)><<<
+          blocks, threads, 0, stream>>>(data_transfer, main_tid, tail_tid);
+      break;
+    }
+    case 8: {
+      auto data_transfer = BroadcastArgsWarpper<T, ET, VecSize, 8>(
+          ins, offset_calculator, out, vec_len);
+      ElementwiseBroadcastKernel<T, ET, VecSize, decltype(data_transfer)><<<
+          blocks, threads, 0, stream>>>(data_transfer, main_tid, tail_tid);
+      break;
+    }
+    default: {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "The maximum dimension of input tensor is expected to be less than "
+          "%d, but received %d.\n",
+          framework::DDim::kMaxRank, merge_dims.dim_size));
+    }
+  }
+}
+
+template <ElementwiseType ET, typename T, typename Functor>
+void LaunchBroadcastElementwiseCudaKernel(
+    const platform::CUDADeviceContext &ctx,
+    const std::vector<const framework::Tensor *> &ins, framework::Tensor *out,
+    Functor func, int axis) {
+  // The vectorization width is capped by the pointer alignment of the output
+  // and of any input that already has the full output shape; broadcast
+  // inputs are gathered element by element, so they do not constrain it.
+  int in_vec_size = 4;
+  for (auto *in : ins) {
+    auto temp_size = GetVectorizedSizeImpl(in->data<T>());
+    in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
+                                            : in_vec_size;
+  }
+  int out_vec_size = GetVectorizedSizeImpl(out->data<T>());
+  int vec_size = std::min(out_vec_size, in_vec_size);
+
+  switch (vec_size) {
+    case 4: {
+      LaunchBroadcastKernelForDifferentDimSize<T, ET, 4>(ctx, ins, out, axis);
+      break;
+    }
+    case 2: {
+      LaunchBroadcastKernelForDifferentDimSize<T, ET, 2>(ctx, ins, out, axis);
+      break;
+    }
+    default: {
+      LaunchBroadcastKernelForDifferentDimSize<T, ET, 1>(ctx, ins, out, axis);
+      break;
+    }
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_broadcast_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_broadcast_impl.cu.h
new file mode 100644
index 0000000000000000000000000000000000000000..083bc6a1378ae39ddda7b00f6514f2162fa9c914
--- /dev/null
+++ b/paddle/fluid/operators/elementwise/elementwise_op_broadcast_impl.cu.h
@@ -0,0 +1,63 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
+
+#define INT_BITS 32
+
+namespace paddle {
+namespace operators {
+
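+// FastDivMod performs division and modulo by a fixed 32-bit divisor with one
+// multiply-high and one shift, using the classic magic-number scheme:
+// shift_val is the smallest s with 2^s >= divisor, and
+// multiplier = (2^32 * (2^s - divisor)) / divisor + 1, so that
+// Div(n) = (umulhi(n, multiplier) + n) >> s recovers n / divisor.
+// E.g. divisor = 3 gives shift_val = 2 and multiplier = 1431655766, so
+// Div(9) = (umulhi(9, 1431655766) + 9) >> 2 = (3 + 9) >> 2 = 3.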
+struct FastDivMod {
+  // The 1st value returned by Divmod is the quotient of the input divided by
+  // the stored divisor; the 2nd value is the remainder.
+  using DivModT = CudaAlignedVector<uint32_t, 2>;
+
+  FastDivMod() {}
+  HOSTDEVICE FastDivMod(uint32_t d) : divisor(d) {
+    static_assert(sizeof(unsigned int) == 4,
+                  "Only Support 32-bit unsigned int.");
+
+    for (shift_val = 0; shift_val < INT_BITS; ++shift_val) {
+      auto shift_limit = 1 << shift_val;
+      if (shift_limit >= divisor) break;
+    }
+    uint64_t long_one = 1;
+    uint64_t temp_div =
+        ((long_one << INT_BITS) * ((long_one << shift_val) - divisor)) /
+            divisor +
+        1;
+    multiplier = temp_div;
+  }
+
+  __device__ __forceinline__ uint32_t Div(uint32_t n) const {
+    uint32_t t = __umulhi(n, multiplier);
+    return (t + n) >> shift_val;
+  }
+
+  __device__ __forceinline__ DivModT Divmod(uint32_t n) {
+    uint32_t q = Div(n);
+    DivModT result = {q, n - q * divisor};
+    return result;
+  }
+
+  int32_t divisor;
+  int32_t shift_val;
+  uint32_t multiplier;
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
index 38b1afbdc3342e8bc4d9901b64bae808fd9d3915..449863f93f2a733e687d558b885248fe35be327c 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -197,6 +197,7 @@ void LaunchElementwiseCudaKernel(
   OutT *out = (*outs)[0]->data<OutT>();
   // cuda kernel
   auto stream = ctx.stream();
+
   switch (vec_size) {
     case 4:
       VectorizedKernel<ET, 4><<<grid_size, block_size, 0, stream>>>(