From 890d6bc00f4384194f37784411be0d6d70e9498b Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Thu, 22 Apr 2021 14:27:52 +0800
Subject: [PATCH] Modify some contents for elementwise op impl (#32414)

---
 .../elementwise/elementwise_add_op.cu         |  5 +--
 .../elementwise/elementwise_op_impl.cu.h      | 32 ++++++++++++-------
 2 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
index 0ca03fc32f..5c444e752e 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/complex128.h"
 #include "paddle/fluid/platform/complex64.h"
@@ -34,7 +33,9 @@ namespace operators {
 */
 template <typename T>
 struct CudaAddFunctor {
-  inline HOSTDEVICE T operator()(T args[]) const { return args[0] + args[1]; }
+  __device__ __forceinline__ T operator()(const T* args) const {
+    return args[0] + args[1];
+  }
 };
 
 template <typename T>
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
index 36add21129..321826ec64 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -13,6 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/float16.h"
+
+#ifdef __HIPCC__
+#define ELEMENTWISE_BLOCK_SIZE 256
+#else
+#define ELEMENTWISE_BLOCK_SIZE 512
+#endif
+
 namespace paddle {
 namespace operators {
 
@@ -90,8 +101,7 @@ struct ElementwiseDataWrapper {
 
 template <ElementwiseType ET, int VecSize, typename T, typename Functor>
 __device__ void VectorizedKernelImpl(
-    ElementwiseDataWrapper<ET, VecSize, T> data, int size, Functor func,
-    int tid) {
+    ElementwiseDataWrapper<ET, VecSize, T> data, Functor func, int tid) {
   using VecType = CudaAlignedVector<T, VecSize>;
   VecType ins_vec[ET];
   VecType out_vec;
@@ -121,10 +131,9 @@ __device__ void VectorizedKernelImpl(
   data.store_vector(out_vec, tid);
 }
 
-template <ElementwiseType ET, typename T, typename Functor>
-__device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, 1, T> data,
-                                 int size, Functor func, int start,
-                                 int remain) {
+template <ElementwiseType ET, int VecSize, typename T, typename Functor>
+__device__ void ScalarKernelImpl(ElementwiseDataWrapper<ET, VecSize, T> data,
+                                 Functor func, int start, int remain) {
   T ins[ET];
   T out;
 
@@ -146,12 +155,11 @@ __global__ void VectorizedKernel(const T *__restrict__ in0,
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int remain = size - VecSize * tid;
   remain = remain > 0 ? remain : 0;
+  auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1);
   if (remain >= VecSize) {
-    auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1);
-    VectorizedKernelImpl(data, size, func, tid);
+    VectorizedKernelImpl(data, func, tid);
   } else {
-    auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1);
-    ScalarKernelImpl(data, size, func, tid * VecSize, remain);
+    ScalarKernelImpl(data, func, tid * VecSize, remain);
   }
 }
 
@@ -162,7 +170,7 @@ __global__ void ScalarKernel(const T *__restrict__ in0,
   auto data = ElementwiseDataWrapper<ET, 1, T>(out, in0, in1);
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
   int remain = tid < size ? 1 : 0;
-  ScalarKernelImpl(data, size, func, tid, remain);
+  ScalarKernelImpl(data, func, tid, remain);
 }
 
 template <ElementwiseType ET, typename T, typename Functor>
@@ -173,7 +181,7 @@ void LaunchElementwiseCudaKernel(
   // calculate the max vec_size for all ins and outs
   auto size = ins[0]->numel();
   int vec_size = GetVectorizedSize<T>(ins, *outs);
-  int block_size = PADDLE_CUDA_THREAD_SIZE;
+  int block_size = ELEMENTWISE_BLOCK_SIZE;
   int grid_size =
       ((size + vec_size - 1) / vec_size + block_size - 1) / block_size;
   const T *in0 = ins[0]->data<T>();
-- 
GitLab
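
Note (not part of the patch): the standalone sketch below illustrates the two conventions the patch adopts: a __device__ __forceinline__ functor that reads its operands through a const T* array (replacing the old inline HOSTDEVICE T operator()(T args[]) form), and a launch block size chosen per platform via the ELEMENTWISE_BLOCK_SIZE macro (256 under HIP, 512 under CUDA). AddFunctor, BinaryKernel, and the host code are hypothetical names invented for illustration; only the operator() signature style and the macro values mirror the patch.

// Minimal CUDA sketch of the functor/launch pattern; compile with nvcc.
#include <cstdio>
#include <cuda_runtime.h>

// Mirrors the platform-dependent block size introduced by the patch.
#ifdef __HIPCC__
#define ELEMENTWISE_BLOCK_SIZE 256
#else
#define ELEMENTWISE_BLOCK_SIZE 512
#endif

template <typename T>
struct AddFunctor {
  // Device-only, force-inlined, operands passed via const T* -- the same
  // signature style CudaAddFunctor switches to in this patch.
  __device__ __forceinline__ T operator()(const T* args) const {
    return args[0] + args[1];
  }
};

// Hypothetical elementwise kernel: gathers per-thread operands into a small
// array and applies the functor, the calling convention the functor expects.
template <typename T, typename Functor>
__global__ void BinaryKernel(const T* __restrict__ in0,
                             const T* __restrict__ in1, T* out, int size,
                             Functor func) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < size) {
    T args[2] = {in0[tid], in1[tid]};
    out[tid] = func(args);
  }
}

int main() {
  const int n = 1 << 20;
  float *a, *b, *c;
  cudaMallocManaged(&a, n * sizeof(float));
  cudaMallocManaged(&b, n * sizeof(float));
  cudaMallocManaged(&c, n * sizeof(float));
  for (int i = 0; i < n; ++i) { a[i] = 1.f; b[i] = 2.f; }

  // Grid sizing follows the same block-size macro the patch introduces.
  int block = ELEMENTWISE_BLOCK_SIZE;
  int grid = (n + block - 1) / block;
  BinaryKernel<<<grid, block>>>(a, b, c, n, AddFunctor<float>());
  cudaDeviceSynchronize();

  printf("c[0] = %f\n", c[0]);  // expect 3.0
  cudaFree(a); cudaFree(b); cudaFree(c);
  return 0;
}

On the refactor itself: hoisting the ElementwiseDataWrapper construction above the vectorized/scalar branch means both branches now share one ElementwiseDataWrapper<ET, VecSize, T> instance, which is why ScalarKernelImpl's template gains the VecSize parameter and both impl functions drop the unused size argument.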