diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h
index 7491d6189ebde3b4c70f16c0b9b6e66eea535605..a708cbbfaacfc93fea2367a086a81a75ddeb9f71 100644
--- a/paddle/fluid/operators/dropout_impl.cu.h
+++ b/paddle/fluid/operators/dropout_impl.cu.h
@@ -34,7 +34,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/dropout_op.h"
 #include "paddle/fluid/platform/aligned_vector.h"
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
-#include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise_no_broadcast.cu.h"
+#include "paddle/pten/kernels/funcs/cuda_kernel_config.h"
 
 namespace paddle {
 namespace operators {
@@ -193,7 +193,7 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
     // VectorizedRandomGenerator use curand_uniform4, so we only support
     // vec_size is 4;
     int vec_size = (platform::GetVectorizedSize(x_data) == 4) ? 4 : 1;
-    int block_size = pten::GetThreadsConfig(dev_ctx, x_numel, vec_size);
+    int block_size = pten::funcs::GetThreadsConfig(dev_ctx, x_numel, vec_size);
     int grid_size =
         ((x_numel + vec_size - 1) / vec_size + block_size - 1) / block_size;
 
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h
index a145848bad96c62828df78d98f8951cd75ef1dc5..3929699955a17f63d5fa2deead9ee0a3659e267f 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -31,7 +31,7 @@ limitations under the License. */
 
 // only can include the headers in paddle/pten/include dirs
 #include "paddle/pten/api/lib/utils/tensor_utils.h"
-#include "paddle/pten/kernels/cpu/elementwise_impl.h"
+#include "paddle/pten/kernels/cpu/elementwise.h"
 
 #if defined(__NVCC__) || defined(__HIPCC__)
 #ifdef __NVCC__
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
index 27897f10a3c6320c33542adc99484b8b5203e5d2..1d8acd5eca5d9cd56050b965e37152092f924b33 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
@@ -23,7 +23,7 @@ limitations under the License. */
 // only can include the headers in paddle/top/api dirs
 #include "paddle/pten/api/lib/utils/tensor_utils.h"
 #include "paddle/pten/include/core.h"
-#include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h"
+#include "paddle/pten/kernels/gpu/elementwise.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/pten/kernels/cpu/CMakeLists.txt b/paddle/pten/kernels/cpu/CMakeLists.txt
index 9bf3df598e4c03a38452fd8d0666bf10242bb7de..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644
--- a/paddle/pten/kernels/cpu/CMakeLists.txt
+++ b/paddle/pten/kernels/cpu/CMakeLists.txt
@@ -1 +0,0 @@
-cc_library(math_cpu SRCS math.cc DEPS dense_tensor kernel_context kernel_factory eigen_function blas pten_transpose_cpu cast_kernel)
diff --git a/paddle/pten/kernels/cpu/elementwise_impl.h b/paddle/pten/kernels/cpu/elementwise.h
similarity index 100%
rename from paddle/pten/kernels/cpu/elementwise_impl.h
rename to paddle/pten/kernels/cpu/elementwise.h
diff --git a/paddle/pten/kernels/cpu/math.cc b/paddle/pten/kernels/cpu/math.cc
deleted file mode 100644
index b4642d475d56639aff2fde22e3b4ff8c37515d57..0000000000000000000000000000000000000000
--- a/paddle/pten/kernels/cpu/math.cc
+++ /dev/null
@@ -1,15 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-namespace pten {}  // namespace pten
diff --git a/paddle/pten/kernels/cpu/math_kernel.cc b/paddle/pten/kernels/cpu/math_kernel.cc
index 4f895d9514a97a8c710634a40c587e527179d8f9..2a696584bc78101a099230d1578b5d7701d337e6 100644
--- a/paddle/pten/kernels/cpu/math_kernel.cc
+++ b/paddle/pten/kernels/cpu/math_kernel.cc
@@ -18,7 +18,7 @@
 #include "paddle/pten/backends/cpu/cpu_context.h"
 #include "paddle/pten/common/scalar.h"
 #include "paddle/pten/core/kernel_registry.h"
-#include "paddle/pten/kernels/cpu/elementwise_impl.h"
+#include "paddle/pten/kernels/cpu/elementwise.h"
 #include "paddle/pten/kernels/cpu/reduce.h"
 #include "paddle/pten/kernels/funcs/elementwise_functor.h"
 
diff --git a/paddle/pten/kernels/funcs/cuda_kernel_config.h b/paddle/pten/kernels/funcs/cuda_kernel_config.h
new file mode 100644
index 0000000000000000000000000000000000000000..27fbc1de55a353023689dc372922ad2ac9ac5933
--- /dev/null
+++ b/paddle/pten/kernels/funcs/cuda_kernel_config.h
@@ -0,0 +1,55 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
+
+#ifdef __HIPCC__
+#define ELEMENTWISE_BLOCK_SIZE 256
+#else
+#define ELEMENTWISE_BLOCK_SIZE 512
+#endif
+
+namespace pten {
+namespace funcs {
+/*
+* According to NVIDIA, if number of threads per block is 64/128/256/512,
+* cuda performs better. And number of blocks should be greater (at least
+* 2x~4x) than number of SMs. Hence, SM count is took into account within
+* this function to determine the right number of threads per block.
+*/
+inline int GetThreadsConfig(const paddle::platform::CUDADeviceContext &ctx,
+                            int64_t numel,
+                            int vec_size) {
+  int threads = ELEMENTWISE_BLOCK_SIZE;
+  int sm_count = ctx.GetSMCount();
+  int active_threads_num = numel / vec_size;
+  if (active_threads_num / (sm_count << 1) < ELEMENTWISE_BLOCK_SIZE) {
+    // Round up threads number into an exponential multiple of 2, while number
+    // of acitve blocks is about twice of SM, to acquire better performance.
+    threads = paddle::platform::RoundToPowerOfTwo(active_threads_num /
+                                                  (sm_count << 1));
+  } else if (active_threads_num / (sm_count << 2) < ELEMENTWISE_BLOCK_SIZE) {
+    // Round up threads number into an exponential multiple of 2, while number
+    // of acitve blocks is about 4 times of SM, to acquire better performance.
+ threads = paddle::platform::RoundToPowerOfTwo(active_threads_num / + (sm_count << 2)); + } + // Number of threads per block shall be larger than 64. + return std::max(64, threads); +} + +} // namespace funcs +} // namespace pten diff --git a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h b/paddle/pten/kernels/gpu/elementwise.h similarity index 61% rename from paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h rename to paddle/pten/kernels/gpu/elementwise.h index 134ad08913c2149989a5240ea3337b89124e711d..f78328c01a30d9fc74024898c57dabe6faaefa25 100644 --- a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h +++ b/paddle/pten/kernels/gpu/elementwise.h @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,11 +14,309 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" +#include "paddle/fluid/platform/aligned_vector.h" +#include "paddle/fluid/platform/function_traits.h" #include "paddle/pten/core/dense_tensor.h" -#include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h" +#include "paddle/pten/kernels/funcs/cuda_kernel_config.h" namespace pten { +namespace kps = paddle::operators::kernel_primitives; +enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 }; + +/* Packing scalar type T(float, int etc.) into Array type + for supporting multiple-output feature in elementwise system.*/ +template +using ConditionalT = + typename std::conditional_t>; + +template +struct ElementwisePrimitiveCaller { + __device__ inline void operator()(Functor func, + InT (*args)[VecSize], + OutT *result); +}; + +template +struct ElementwisePrimitiveCaller { + __device__ inline void operator()(Functor func, + InT (*args)[VecSize], + OutT *result) { + kps::ElementwiseAny( + result, args, func); + } +}; + +template +struct ElementwisePrimitiveCaller { + __device__ inline void operator()(Functor func, + InT (*args)[VecSize], + OutT *result) { + kps::ElementwiseUnary( + result, args[0], func); + } +}; + +template +struct ElementwisePrimitiveCaller { + __device__ inline void operator()(Functor func, + InT (*args)[VecSize], + OutT *result) { + kps::ElementwiseBinary( + result, args[0], args[1], func); + } +}; + +template +struct ElementwisePrimitiveCaller { + __device__ inline void operator()(Functor func, + InT (*args)[VecSize], + OutT *result) { + kps::ElementwiseTernary( + result, args[0], args[1], args[2], func); + } +}; + +template +struct ElementwiseWriteDataCaller { + __device__ __forceinline__ void operator()( + paddle::framework::Array outs, + ConditionalT src[VecSize], + int block_offset, + int num) { + OutT dst[NumOuts][VecSize]; +#pragma unroll + for (int i = 0; i < VecSize; ++i) { +#pragma unroll + for (int j = 0; j < NumOuts; ++j) { + dst[j][i] = (src[i])[j]; + } + } +#pragma unroll + for (int i = 0; i < NumOuts; ++i) { + kps::WriteData( + outs[i] + block_offset, dst[i], num); + } + } +}; + +template +struct ElementwiseWriteDataCaller { + __device__ __forceinline__ void operator()( + paddle::framework::Array outs, + OutT src[VecSize], + int block_offset, + int num) { + kps::WriteData( + outs[0] + block_offset, src, num); + } +}; + +template +__device__ void VectorizedElementwiseKernelImpl( + const 
paddle::framework::Array &in, + paddle::framework::Array outs, + int num, + int data_offset, + Functor func) { + InT args[Arity][VecSize]; + ConditionalT result[VecSize]; + +#pragma unroll + for (int i = 0; i < Arity; i++) { + kps::Init(args[i], static_cast(1.0f)); + kps::ReadData( + args[i], in[i] + data_offset, num); + } + + constexpr bool kCallElementwiseAny = + paddle::platform::FunctionTraits::has_pointer_args; + ElementwisePrimitiveCaller, + VecSize, + Functor, + Arity, + kCallElementwiseAny>()(func, args, result); + + ElementwiseWriteDataCaller()( + outs, result, data_offset, num); +} + +template +__global__ void VectorizedElementwiseKernel( + paddle::framework::Array ins, + paddle::framework::Array outs, + int size, + int main_offset, + Functor func) { + int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; + int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize; + for (; data_offset < main_offset; data_offset += stride) { + VectorizedElementwiseKernelImpl( + ins, outs, VecSize * BLOCK_NUM_X, data_offset, func); + } + + int num = size - data_offset; + if (num > 0) { + VectorizedElementwiseKernelImpl(ins, outs, num, data_offset, func); + } +} + +template +int GetVectorizedSizeForTensors(const std::vector &ins, + const std::vector &outs) { + int vec_size = 4; + for (auto iter = ins.begin(); iter != ins.end(); ++iter) { + vec_size = std::min( + vec_size, paddle::platform::GetVectorizedSize((*iter)->data())); + } + for (auto iter = outs.begin(); iter != outs.end(); ++iter) { + vec_size = std::min( + vec_size, paddle::platform::GetVectorizedSize((*iter)->data())); + } + return vec_size; +} + +template +void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx, + const std::vector &ins, + std::vector *outs, + Functor func) { + auto numel = ins[0]->numel(); + int block_size = funcs::GetThreadsConfig(ctx, numel, VecSize); + int grid_size = + ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size; + auto stream = ctx.stream(); + paddle::framework::Array ins_data; + paddle::framework::Array outs_data; + + for (int i = 0; i < Arity; ++i) { + ins_data[i] = ins[i]->data(); + } + for (int i = 0; i < NumOuts; ++i) { + outs_data[i] = (*outs)[i]->mutable_data(); + } +#ifdef PADDLE_WITH_XPU2 + block_size = 128; + grid_size = 8; + int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size; + VectorizedElementwiseKernel<<>>( + ins_data, outs_data, numel, main_offset, func); +#else + int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size; + VectorizedElementwiseKernel<<>>( + ins_data, outs_data, numel, main_offset, func); +#endif +} + +template +void LaunchSameDimsElementwiseCudaKernel( + const paddle::platform::CUDADeviceContext &ctx, + const std::vector &ins, + std::vector *outs, + Functor func) { + using Traits = paddle::platform::FunctionTraits; + const int kArity = + Traits::has_pointer_args ? static_cast(ET) : Traits::arity; + PADDLE_ENFORCE_EQ(ins.size(), + kArity, + paddle::platform::errors::InvalidArgument( + "The number of inputs is expected to be equal to the " + "arity of functor. 
But recieved: the number of inputs " + "is %d, the arity of functor is %d.", + ins.size(), + kArity)); + PADDLE_ENFORCE_EQ(outs->size(), + NumOuts, + paddle::platform::errors::InvalidArgument( + "Number of outputs shall equal to number of functions, " + "but number of outputs is %d, of functions is %d.", + outs->size(), + NumOuts)); + + if (NumOuts > 1) { + for (int i = 1; i < NumOuts; ++i) { + PADDLE_ENFORCE_EQ( + (*outs)[i]->dims(), + (*outs)[0]->dims(), + paddle::platform::errors::InvalidArgument( + "The shape of each output tensor shall be identical yet, " + "but %dth output tensor`s shape is not.", + i)); + } + } + + // calculate the max vec_size for all ins and outs + int vec_size = GetVectorizedSizeForTensors(ins, *outs); + switch (vec_size) { + case 4: + ElementwiseCudaKernel( + ctx, ins, outs, func); + break; + case 2: + ElementwiseCudaKernel( + ctx, ins, outs, func); + break; + case 1: + ElementwiseCudaKernel( + ctx, ins, outs, func); + break; + default: { + PADDLE_THROW(paddle::platform::errors::Unimplemented( + "Unsupported vectorized size: %d !", vec_size)); + break; + } + } +} + struct DimensionsTransform { using DimVector = std::vector; typedef void (*MergeFunctor)( @@ -532,4 +830,34 @@ void LaunchBroadcastElementwiseCudaKernel( } } +template +void LaunchElementwiseCudaKernel( + const paddle::platform::CUDADeviceContext &cuda_ctx, + const std::vector &ins, + std::vector *outs, + int axis, + Functor func) { + std::vector dims_size; + bool no_broadcast_flag = true; + for (auto *in : ins) { + no_broadcast_flag &= ins[0]->dims() == in->dims(); + dims_size.emplace_back(in->dims().size()); + } + if (no_broadcast_flag) { + LaunchSameDimsElementwiseCudaKernel( + cuda_ctx, ins, outs, func); + } else { + axis = axis == -1 + ? *std::max_element(dims_size.begin(), dims_size.end()) - + *std::min_element(dims_size.begin(), dims_size.end()) + : axis; + LaunchBroadcastElementwiseCudaKernel( + cuda_ctx, ins, outs, axis, func); + } +} + } // namespace pten diff --git a/paddle/pten/kernels/gpu/math_kernel.cu b/paddle/pten/kernels/gpu/math_kernel.cu index 051f7cb3bdd05026c96d82c41ff43c61a0934a07..f41934313d6740e44d198c411ccea05dabe1bda6 100644 --- a/paddle/pten/kernels/gpu/math_kernel.cu +++ b/paddle/pten/kernels/gpu/math_kernel.cu @@ -16,8 +16,8 @@ limitations under the License. */ #include "paddle/pten/backends/gpu/gpu_context.h" #include "paddle/pten/kernels/funcs/elementwise_functor.h" +#include "paddle/pten/kernels/gpu/elementwise.h" #include "paddle/pten/kernels/gpu/reduce.h" -#include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h" #ifdef __NVCC__ #include "cub/cub.cuh" @@ -30,12 +30,9 @@ namespace cub = hipcub; #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" -#include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/core/convert_utils.h" #include "paddle/pten/core/kernel_registry.h" -namespace kps = paddle::operators::kernel_primitives; - namespace pten { #define DEFINE_CUDA_ELEMENTWISE_OP(name) \ diff --git a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h deleted file mode 100644 index 83d662b14e7fc5b06d8a92886e1b5820cd342bb9..0000000000000000000000000000000000000000 --- a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h" -#include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise_no_broadcast.cu.h" - -namespace pten { - -template -void LaunchElementwiseCudaKernel( - const paddle::platform::CUDADeviceContext &cuda_ctx, - const std::vector &ins, - std::vector *outs, - int axis, - Functor func) { - std::vector dims_size; - bool no_broadcast_flag = true; - for (auto *in : ins) { - no_broadcast_flag &= ins[0]->dims() == in->dims(); - dims_size.emplace_back(in->dims().size()); - } - if (no_broadcast_flag) { - LaunchSameDimsElementwiseCudaKernel( - cuda_ctx, ins, outs, func); - } else { - axis = axis == -1 - ? *std::max_element(dims_size.begin(), dims_size.end()) - - *std::min_element(dims_size.begin(), dims_size.end()) - : axis; - LaunchBroadcastElementwiseCudaKernel( - cuda_ctx, ins, outs, axis, func); - } -} - -} // namespace pten diff --git a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h deleted file mode 100644 index ae384693249a48ab5576042d88dc3a3f546c154d..0000000000000000000000000000000000000000 --- a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" -#include "paddle/fluid/platform/aligned_vector.h" -#include "paddle/fluid/platform/function_traits.h" -#include "paddle/pten/core/dense_tensor.h" -#include "paddle/pten/kernels/funcs/elementwise_base.h" - -namespace pten { -namespace kps = paddle::operators::kernel_primitives; -enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 }; - -/* Packing scalar type T(float, int etc.) 
into Array type - for supporting multiple-output feature in elementwise system.*/ -template -using ConditionalT = - typename std::conditional_t>; - -template -struct ElementwisePrimitiveCaller { - __device__ inline void operator()(Functor func, - InT (*args)[VecSize], - OutT *result); -}; - -template -struct ElementwisePrimitiveCaller { - __device__ inline void operator()(Functor func, - InT (*args)[VecSize], - OutT *result) { - kps::ElementwiseAny( - result, args, func); - } -}; - -template -struct ElementwisePrimitiveCaller { - __device__ inline void operator()(Functor func, - InT (*args)[VecSize], - OutT *result) { - kps::ElementwiseUnary( - result, args[0], func); - } -}; - -template -struct ElementwisePrimitiveCaller { - __device__ inline void operator()(Functor func, - InT (*args)[VecSize], - OutT *result) { - kps::ElementwiseBinary( - result, args[0], args[1], func); - } -}; - -template -struct ElementwisePrimitiveCaller { - __device__ inline void operator()(Functor func, - InT (*args)[VecSize], - OutT *result) { - kps::ElementwiseTernary( - result, args[0], args[1], args[2], func); - } -}; - -template -struct ElementwiseWriteDataCaller { - __device__ __forceinline__ void operator()( - paddle::framework::Array outs, - ConditionalT src[VecSize], - int block_offset, - int num) { - OutT dst[NumOuts][VecSize]; -#pragma unroll - for (int i = 0; i < VecSize; ++i) { -#pragma unroll - for (int j = 0; j < NumOuts; ++j) { - dst[j][i] = (src[i])[j]; - } - } -#pragma unroll - for (int i = 0; i < NumOuts; ++i) { - kps::WriteData( - outs[i] + block_offset, dst[i], num); - } - } -}; - -template -struct ElementwiseWriteDataCaller { - __device__ __forceinline__ void operator()( - paddle::framework::Array outs, - OutT src[VecSize], - int block_offset, - int num) { - kps::WriteData( - outs[0] + block_offset, src, num); - } -}; - -} // namespace pten diff --git a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_no_broadcast.cu.h b/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_no_broadcast.cu.h deleted file mode 100644 index f37e3b0b5e3b36c0381810ef167b4038809bfad8..0000000000000000000000000000000000000000 --- a/paddle/pten/kernels/hybird/cuda/elementwise/elementwise_no_broadcast.cu.h +++ /dev/null @@ -1,253 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h" - -#ifdef __HIPCC__ -#define ELEMENTWISE_BLOCK_SIZE 256 -#else -#define ELEMENTWISE_BLOCK_SIZE 512 -#endif - -namespace pten { - -/* -* According to NVIDIA, if number of threads per block is 64/128/256/512, -* cuda performs better. And number of blocks should be greater (at least -* 2x~4x) than number of SMs. Hence, SM count is took into account within -* this function to determine the right number of threads per block. 
-*/ -inline int GetThreadsConfig(const paddle::platform::CUDADeviceContext &ctx, - int64_t numel, - int vec_size) { - int threads = ELEMENTWISE_BLOCK_SIZE; - int sm_count = ctx.GetSMCount(); - int active_threads_num = numel / vec_size; - if (active_threads_num / (sm_count << 1) < ELEMENTWISE_BLOCK_SIZE) { - // Round up threads number into an exponential multiple of 2, while number - // of acitve blocks is about twice of SM, to acquire better performance. - threads = paddle::platform::RoundToPowerOfTwo(active_threads_num / - (sm_count << 1)); - } else if (active_threads_num / (sm_count << 2) < ELEMENTWISE_BLOCK_SIZE) { - // Round up threads number into an exponential multiple of 2, while number - // of acitve blocks is about 4 times of SM, to acquire better performance. - threads = paddle::platform::RoundToPowerOfTwo(active_threads_num / - (sm_count << 2)); - } - // Number of threads per block shall be larger than 64. - return std::max(64, threads); -} - -template -__device__ void VectorizedElementwiseKernelImpl( - const paddle::framework::Array &in, - paddle::framework::Array outs, - int num, - int data_offset, - Functor func) { - InT args[Arity][VecSize]; - ConditionalT result[VecSize]; - -#pragma unroll - for (int i = 0; i < Arity; i++) { - kps::Init(args[i], static_cast(1.0f)); - kps::ReadData( - args[i], in[i] + data_offset, num); - } - - constexpr bool kCallElementwiseAny = - paddle::platform::FunctionTraits::has_pointer_args; - ElementwisePrimitiveCaller, - VecSize, - Functor, - Arity, - kCallElementwiseAny>()(func, args, result); - - ElementwiseWriteDataCaller()( - outs, result, data_offset, num); -} - -template -__global__ void VectorizedElementwiseKernel( - paddle::framework::Array ins, - paddle::framework::Array outs, - int size, - int main_offset, - Functor func) { - int data_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize; - int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize; - for (; data_offset < main_offset; data_offset += stride) { - VectorizedElementwiseKernelImpl( - ins, outs, VecSize * BLOCK_NUM_X, data_offset, func); - } - - int num = size - data_offset; - if (num > 0) { - VectorizedElementwiseKernelImpl(ins, outs, num, data_offset, func); - } -} - -template -int GetVectorizedSizeForTensors(const std::vector &ins, - const std::vector &outs) { - int vec_size = 4; - for (auto iter = ins.begin(); iter != ins.end(); ++iter) { - vec_size = std::min( - vec_size, paddle::platform::GetVectorizedSize((*iter)->data())); - } - for (auto iter = outs.begin(); iter != outs.end(); ++iter) { - vec_size = std::min( - vec_size, paddle::platform::GetVectorizedSize((*iter)->data())); - } - return vec_size; -} - -template -void ElementwiseCudaKernel(const paddle::platform::CUDADeviceContext &ctx, - const std::vector &ins, - std::vector *outs, - Functor func) { - auto numel = ins[0]->numel(); - int block_size = GetThreadsConfig(ctx, numel, VecSize); - int grid_size = - ((numel + VecSize - 1) / VecSize + block_size - 1) / block_size; - auto stream = ctx.stream(); - paddle::framework::Array ins_data; - paddle::framework::Array outs_data; - - for (int i = 0; i < Arity; ++i) { - ins_data[i] = ins[i]->data(); - } - for (int i = 0; i < NumOuts; ++i) { - outs_data[i] = (*outs)[i]->mutable_data(); - } -#ifdef PADDLE_WITH_XPU2 - block_size = 128; - grid_size = 8; - int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size; - VectorizedElementwiseKernel<<>>( - ins_data, outs_data, numel, main_offset, func); -#else - int main_offset = (numel / (VecSize * block_size)) * VecSize * block_size; - 
VectorizedElementwiseKernel<<>>( - ins_data, outs_data, numel, main_offset, func); -#endif -} - -template -void LaunchSameDimsElementwiseCudaKernel( - const paddle::platform::CUDADeviceContext &ctx, - const std::vector &ins, - std::vector *outs, - Functor func) { - using Traits = paddle::platform::FunctionTraits; - const int kArity = - Traits::has_pointer_args ? static_cast(ET) : Traits::arity; - PADDLE_ENFORCE_EQ(ins.size(), - kArity, - paddle::platform::errors::InvalidArgument( - "The number of inputs is expected to be equal to the " - "arity of functor. But recieved: the number of inputs " - "is %d, the arity of functor is %d.", - ins.size(), - kArity)); - PADDLE_ENFORCE_EQ(outs->size(), - NumOuts, - paddle::platform::errors::InvalidArgument( - "Number of outputs shall equal to number of functions, " - "but number of outputs is %d, of functions is %d.", - outs->size(), - NumOuts)); - - if (NumOuts > 1) { - for (int i = 1; i < NumOuts; ++i) { - PADDLE_ENFORCE_EQ( - (*outs)[i]->dims(), - (*outs)[0]->dims(), - paddle::platform::errors::InvalidArgument( - "The shape of each output tensor shall be identical yet, " - "but %dth output tensor`s shape is not.", - i)); - } - } - - // calculate the max vec_size for all ins and outs - int vec_size = GetVectorizedSizeForTensors(ins, *outs); - switch (vec_size) { - case 4: - ElementwiseCudaKernel( - ctx, ins, outs, func); - break; - case 2: - ElementwiseCudaKernel( - ctx, ins, outs, func); - break; - case 1: - ElementwiseCudaKernel( - ctx, ins, outs, func); - break; - default: { - PADDLE_THROW(paddle::platform::errors::Unimplemented( - "Unsupported vectorized size: %d !", vec_size)); - break; - } - } -} - -} // namespace pten diff --git a/paddle/pten/kernels/hybird/general/manipulation.h b/paddle/pten/kernels/hybird/general/manipulation.h deleted file mode 100644 index 85f6b613ac60947ae7c7a993bd3ab4210bf8f667..0000000000000000000000000000000000000000 --- a/paddle/pten/kernels/hybird/general/manipulation.h +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/pten/core/dense_tensor.h" - -namespace pten { -namespace general { - -inline void SetXShape(const DenseTensor& x, DenseTensor* xshape) { - const auto& in_dims = x.meta().dims; - std::vector xshape_dims(in_dims.size() + 1); - xshape_dims[0] = 0; - for (int i = 0; i < in_dims.size(); ++i) { - xshape_dims[i + 1] = in_dims[i]; - } - xshape->Resize(paddle::framework::make_ddim(xshape_dims)); - xshape->ResetLoD(x.meta().lod); -} - -} // namespace general -} // namespace pten
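Reviewer note (not part of the patch): the sketch below shows how a call site is expected to use the relocated pten::funcs::GetThreadsConfig after this change, mirroring the updated hunk in dropout_impl.cu.h above. The wrapper function name and the fixed vec_size are illustrative assumptions only.

// Illustrative sketch only: derive a CUDA launch configuration with the
// relocated helper; the same arithmetic appears in dropout_impl.cu.h above.
#include "paddle/pten/kernels/funcs/cuda_kernel_config.h"

void LaunchConfigSketch(const paddle::platform::CUDADeviceContext &dev_ctx,
                        int64_t numel) {
  // Assume the buffers permit 4-wide vectorized access (otherwise use 1).
  int vec_size = 4;
  // Threads per block, chosen from the SM count and the vectorized workload.
  int block_size = pten::funcs::GetThreadsConfig(dev_ctx, numel, vec_size);
  // Enough blocks so each vec_size-wide group of elements is covered once.
  int grid_size =
      ((numel + vec_size - 1) / vec_size + block_size - 1) / block_size;
  // A kernel would then be launched with
  // <<<grid_size, block_size, 0, dev_ctx.stream()>>>; the same-dims vs.
  // broadcast dispatch itself now lives in pten::LaunchElementwiseCudaKernel
  // (paddle/pten/kernels/gpu/elementwise.h).
  (void)grid_size;
}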