From 2c18258316d568825dd722b74c63ebc21e28536b Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Sun, 18 Apr 2021 19:15:48 +0800
Subject: [PATCH] Unify the implementation of elementwise operation of same
 dimensions (#32148)

---
 .../elementwise/elementwise_add_op.cu         |  58 ++---
 .../elementwise/elementwise_op_impl.cu.h      | 205 ++++++++++++++++++
 2 files changed, 222 insertions(+), 41 deletions(-)
 create mode 100644 paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h

diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
index 313607d975..0ca03fc32f 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
 #include "paddle/fluid/platform/float16.h"
@@ -23,54 +24,29 @@ namespace plat = paddle::platform;
 namespace paddle {
 namespace operators {
 
+/*
+   input: an array;
+   return: the result of the math functor
+   1. For Unary Op, the length of input array is 1,
+      e.g. Relu: return args[0] > 0 ? args[0] : 0;
+   2. For Binary Op, the length of input array is 2,
+      e.g. Add: return args[0] + args[1];
+*/
 template <typename T>
-struct SameDimsElemwiseAdd<
-    platform::CUDADeviceContext, T,
-    typename std::enable_if<!std::is_same<T, platform::float16>::value &&
-                            !std::is_same<T, float>::value>::type> {
-  void operator()(const framework::ExecutionContext& ctx,
-                  const framework::Tensor* x, const framework::Tensor* y,
-                  framework::Tensor* z) {
-    AddRangeFunctor<T> functor(x->data<T>(), y->data<T>(), z->data<T>());
-    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx,
-                                                              x->numel());
-    for_range(functor);
-  }
+struct CudaAddFunctor {
+  inline HOSTDEVICE T operator()(T args[]) const { return args[0] + args[1]; }
 };
 
 template <typename T>
-struct SameDimsElemwiseAdd<
-    platform::CUDADeviceContext, T,
-    typename std::enable_if<std::is_same<T, platform::float16>::value ||
-                            std::is_same<T, float>::value>::type> {
+struct SameDimsElemwiseAdd<platform::CUDADeviceContext, T> {
   void operator()(const framework::ExecutionContext& ctx,
                   const framework::Tensor* x, const framework::Tensor* y,
                   framework::Tensor* z) {
-    auto size = x->numel();
-    int vec_size = sizeof(float4) / sizeof(T);
-    dim3 grid_size =
-        dim3(((size + vec_size - 1) / vec_size + PADDLE_CUDA_THREAD_SIZE - 1) /
-                 PADDLE_CUDA_THREAD_SIZE,
-             1);
-    dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
-    if (std::is_same<T, float>::value) {
-      SameDimsElemwiseAddCUDAKernel<<<
-          grid_size, block_size, 0,
-          ctx.template device_context<platform::CUDADeviceContext>()
-              .stream()>>>(x->data<T>(), y->data<T>(), z->data<T>(),
-                           size);
-    } else {
-      const half* x2 =
-          reinterpret_cast<const half*>(x->data<platform::float16>());
-      const half* y2 =
-          reinterpret_cast<const half*>(y->data<platform::float16>());
-      half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
-      SameDimsElemwiseAddCUDAKernel<<<
-          grid_size, block_size, 0,
-          ctx.template device_context<platform::CUDADeviceContext>()
-              .stream()>>>(x2, y2, z2, size);
-    }
+    std::vector<const framework::Tensor*> ins = {x, y};
+    std::vector<framework::Tensor*> outs = {z};
+    LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T>(
+        ctx.template device_context<platform::CUDADeviceContext>(), ins, &outs,
+        CudaAddFunctor<T>());
   }
 };
 
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
new file mode 100644
index 0000000000..36add21129
--- /dev/null
+++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -0,0 +1,205 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#pragma once
+
+namespace paddle {
+namespace operators {
+
+enum ElementwiseType { kUnary = 1, kBinary = 2 };
+
+template <typename T, int Size>
+struct alignas(sizeof(T) * Size) CudaAlignedVector {
+  T val[Size];
+};
+
+template <typename T>
+int GetVectorizedSizeImpl(const T *pointer) {
+  uint64_t address = reinterpret_cast<uint64_t>(pointer);
+  constexpr int vec4 =
+      std::alignment_of<CudaAlignedVector<T, 4>>::value;  // NOLINT
+  constexpr int vec2 =
+      std::alignment_of<CudaAlignedVector<T, 2>>::value;  // NOLINT
+  if (address % vec4 == 0) {
+    return 4;
+  } else if (address % vec2 == 0) {
+    return 2;
+  }
+  return 1;
+}
+
+template <typename T>
+int GetVectorizedSize(const std::vector<const framework::Tensor *> &ins,
+                      const std::vector<framework::Tensor *> &outs) {
+  int vec_size = 4;
+  for (auto iter = ins.begin(); iter != ins.end(); ++iter) {
+    vec_size =
+        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<T>()));
+  }
+  for (auto iter = outs.begin(); iter != outs.end(); ++iter) {
+    vec_size =
+        std::min<int>(vec_size, GetVectorizedSizeImpl((*iter)->data<T>()));
+  }
+  return vec_size;
+}
+
+template <ElementwiseType ET, int VecSize, typename T>
+struct ElementwiseDataWrapper {
+  T *out;
+  const T *in0;
+  const T *in1;
+  __device__ ElementwiseDataWrapper(T *out, const T *in0,
+                                    const T *in1 = nullptr)
+      : out(out), in0(in0), in1(in1) {}
+
+  using VecType = CudaAlignedVector<T, VecSize>;
+
+  inline __device__ void load_vector(VecType args[], int idx) {
+    const VecType *x_vec = reinterpret_cast<const VecType *>(in0);
+    args[0] = x_vec[idx];
+    if (ET == ElementwiseType::kBinary) {
+      const VecType *y_vec = reinterpret_cast<const VecType *>(in1);
+      args[1] = y_vec[idx];
+    }
+  }
+
+  inline __device__ void load_scalar(T args[], int idx) {
+    args[0] = in0[idx];
+    if (ET == ElementwiseType::kBinary) {
+      args[1] = in1[idx];
+    }
+  }
+
+  inline __device__ void store_vector(VecType res, int idx) {
+    VecType *out_vec = reinterpret_cast<VecType *>(out);
+    out_vec[idx] = res;
+  }
+
+  inline __device__ void store_scalar(T res, int idx) { out[idx] = res; }
+};
+
+template <ElementwiseType ET, int VecSize, typename T, typename Functor>
+__device__ void VectorizedKernelImpl(
+    ElementwiseDataWrapper<ET, VecSize, T> data, int size, Functor func,
+    int tid) {
+  using VecType = CudaAlignedVector<T, VecSize>;
+  VecType ins_vec[ET];
+  VecType out_vec;
+  T *ins_ptr[ET];
+  T *out_ptr;
+#pragma unroll
+  for (int i = 0; i < ET; ++i) {
+    ins_ptr[i] = reinterpret_cast<T *>(&(ins_vec[i]));
+  }
+  out_ptr = reinterpret_cast<T *>(&out_vec);
+
+  // load
+  data.load_vector(ins_vec, tid);
+
+// compute
+#pragma unroll
+  for (int i = 0; i < VecSize; ++i) {
+    T ins[ET];
+#pragma unroll
+    for (int j = 0; j < ET; ++j) {
+      ins[j] = ins_ptr[j][i];
+    }
+    out_ptr[i] = func(ins);
+  }
+
+  // store
+  data.store_vector(out_vec, tid);
+}
+
+template <ElementwiseType ET, int VecSize, typename T, typename Functor>
+__device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data,
+                                 int size, Functor func, int start,
+                                 int remain) {
+  T ins[ET];
+  T out;
+
+  for (int i = 0; i < remain; ++i) {
+    int idx = start + i;
+    // load
+    data.load_scalar(ins, idx);
+    // compute
+    out = func(ins);
+    // store
+    data.store_scalar(out, idx);
+  }
+}
+
+template <ElementwiseType ET, int VecSize, typename T, typename Functor>
+__global__ void VectorizedKernel(const T *__restrict__ in0,
+                                 const T *__restrict__ in1, T *out, int size,
+                                 Functor func) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int remain = size - VecSize * tid;
+  remain = remain > 0 ? remain : 0;
+  if (remain >= VecSize) {
+    auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1);
+    VectorizedKernelImpl(data, size, func, tid);
+  } else {
+    auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1);
+    ScalarKernelImpl(data, size, func, tid * VecSize, remain);
+  }
+}
+
+template <ElementwiseType ET, typename T, typename Functor>
+__global__ void ScalarKernel(const T *__restrict__ in0,
+                             const T *__restrict__ in1, T *out, int size,
+                             Functor func) {
+  auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1);
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int remain = tid < size ? 1 : 0;
+  ScalarKernelImpl(data, size, func, tid, remain);
+}
+
+template <ElementwiseType ET, typename T, typename Functor>
+void LaunchElementwiseCudaKernel(
+    const platform::CUDADeviceContext &ctx,
+    const std::vector<const framework::Tensor *> &ins,
+    std::vector<framework::Tensor *> *outs, Functor func) {
+  // calculate the max vec_size for all ins and outs
+  auto size = ins[0]->numel();
+  int vec_size = GetVectorizedSize<T>(ins, *outs);
+  int block_size = PADDLE_CUDA_THREAD_SIZE;
+  int grid_size =
+      ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
+  const T *in0 = ins[0]->data<T>();
+  const T *in1 = (ET == ElementwiseType::kBinary) ? ins[1]->data<T>() : nullptr;
+  T *out = (*outs)[0]->data<T>();
+  // cuda kernel
+  auto stream = ctx.stream();
+  switch (vec_size) {
+    case 4:
+      VectorizedKernel<ET, 4><<<grid_size, block_size, 0, stream>>>(
+          in0, in1, out, size, func);
+      break;
+    case 2:
+      VectorizedKernel<ET, 2><<<grid_size, block_size, 0, stream>>>(
+          in0, in1, out, size, func);
+      break;
+    case 1:
+      ScalarKernel<ET><<<grid_size, block_size, 0, stream>>>(in0, in1, out,
+                                                             size, func);
+      break;
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unsupported vectorized size: %d !", vec_size));
+      break;
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle
-- 
GitLab
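
Usage sketch (illustrative only, not part of the commit): the code below shows how another same-dims binary op could reuse the new LaunchElementwiseCudaKernel, following the SameDimsElemwiseAdd pattern in elementwise_add_op.cu. CudaMulFunctor and LaunchSameDimsMul are hypothetical names introduced here purely for illustration; the Paddle headers and framework types are assumed to be the same ones used by elementwise_add_op.cu.

#include <vector>

#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"

namespace paddle {
namespace operators {

// Hypothetical multiply functor; the patch itself only defines CudaAddFunctor.
template <typename T>
struct CudaMulFunctor {
  inline HOSTDEVICE T operator()(T args[]) const { return args[0] * args[1]; }
};

// Hypothetical helper: computes z = x * y for tensors of identical shape.
template <typename T>
void LaunchSameDimsMul(const platform::CUDADeviceContext& dev_ctx,
                       const framework::Tensor* x, const framework::Tensor* y,
                       framework::Tensor* z) {
  std::vector<const framework::Tensor*> ins = {x, y};
  std::vector<framework::Tensor*> outs = {z};
  // kBinary: the functor receives args[0] and args[1]; the launcher picks a
  // vectorization width of 4, 2, or 1 from the pointer alignment of ins/outs.
  LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T>(
      dev_ctx, ins, &outs, CudaMulFunctor<T>());
}

}  // namespace operators
}  // namespace paddle

Note that the output tensor must already be allocated, since the launcher reads (*outs)[0]->data<T>(), and all tensors must hold the same number of elements; broadcasting is not handled by this same-dimensions path.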