diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu index bd91142882f7650f4dedbee1d3001fe277b12967..331867617bd78a31c011e3d42ee9c1be565e41b6 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu @@ -147,10 +147,10 @@ elementwise_add_grad(const framework::ExecutionContext& ctx, } else if (dx_data != dout_data && dy_data != dout_data) { auto size = x->numel(); int vec_size = max(static_cast(sizeof(float4) / sizeof(T)), 1); - dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1); dim3 grid_size = - dim3(((size + vec_size - 1) / vec_size + PADDLE_CUDA_THREAD_SIZE - 1) / - PADDLE_CUDA_THREAD_SIZE, + dim3(((size + vec_size - 1) / vec_size + ELEMENTWISE_BLOCK_SIZE - 1) / + ELEMENTWISE_BLOCK_SIZE, 1); SimpleElemwiseAddGradCUDAKernel< T><< #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -#ifdef __NVCC__ -#include -#include -#include "cub/cub.cuh" - -#endif -#ifdef __HIPCC__ -#include -#include -#include -namespace cub = hipcub; -#endif -#endif - namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu index 8853fd609f77c968c9b1758e951e6f9ba39aa10a..ce487f284d9f22de9755c61de4c95b8fd2922185 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu @@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "paddle/fluid/operators/elementwise/elementwise_div_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/platform/complex.h" @@ -22,24 +23,6 @@ namespace plat = paddle::platform; namespace paddle { namespace operators { -template -struct CudaDivFunctor { - inline HOSTDEVICE T operator()(const T* args) const { - return args[0] / args[1]; - } -}; - -template -struct CudaDivFunctor::value>> { - inline HOSTDEVICE T operator()(const T* args) const { - PADDLE_ENFORCE(args[1] != 0, - "Invalid Argument Error: Integer division by zero " - "encountered in divide. Please check the input value."); - return args[0] / args[1]; - } -}; - template class ElementwiseDivKernel : public framework::OpKernel { @@ -52,7 +35,7 @@ class ElementwiseDivKernel int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - cuda_ctx, ins, &outs, axis, CudaDivFunctor()); + cuda_ctx, ins, &outs, axis, DivFunctor()); } }; @@ -124,10 +107,10 @@ elementwise_div_grad(const framework::ExecutionContext& ctx, const framework::Tensor* out, const framework::Tensor* dout, framework::Tensor* dx, framework::Tensor* dy) { - dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1); auto size = x->numel(); dim3 grid_size = - dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1); + dim3((size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1); SimpleElemwiseDivGradCUDAKernel< T><<().stream()>>>( diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index a0b9633acb2e5956754d07c53bcdcea7b2896c07..0ec42e54e14442b34463d14af2604bb7f1ef9ab7 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -18,7 +18,6 @@ limitations under the License. */ #include #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" #include "paddle/fluid/operators/math/blas.h" diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu index a0510d95700b27ba360c48f06ac3f99752b993f2..41a0ae00f270d7ac8140133867ad1650d6c38e88 100644 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu @@ -11,25 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -namespace ops = paddle::operators; -namespace plat = paddle::platform; - namespace paddle { namespace operators { -template -struct CudaFloorDivFunctor { - inline HOSTDEVICE T operator()(const T argv[]) const { - PADDLE_ENFORCE(argv[1] != 0, - "InvalidArgument: divide by zero " - "encountered in floor-divide ops, please check.\n"); - return static_cast(std::trunc(argv[0] / argv[1])); - } -}; - template class ElementwiseFloorDivKernel : public framework::OpKernel { @@ -42,13 +30,16 @@ class ElementwiseFloorDivKernel int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - cuda_ctx, ins, &outs, axis, CudaFloorDivFunctor()); + cuda_ctx, ins, &outs, axis, FloorDivFunctor()); } }; } // namespace operators } // namespace paddle +namespace ops = paddle::operators; +namespace plat = paddle::platform; + REGISTER_OP_CUDA_KERNEL( elementwise_floordiv, ops::ElementwiseFloorDivKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h index bc3c2994c847cb65fb6b476c2bbf8076edfffc1d..ae8d2d8625c586ca6afc97c2538b7d83b687547c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h @@ -15,54 +15,13 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace operators { -template -struct FloorDivFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { -#if defined(__HIPCC__) || defined(__CUDA_ARCH__) - if (b == 0) { - printf("Error: Divide by zero encounter in floor_divide\n"); -#ifdef __HIPCC__ - abort(); -#else - asm("trap;"); -#endif - } -#else - if (b == 0) - PADDLE_THROW(platform::errors::InvalidArgument( - "Divide by zero encounter in floor_divide")); -#endif - return static_cast(std::trunc(a / b)); - } -}; - -template -struct InverseFloorDivFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { -#if defined(__HIPCC__) || defined(__CUDA_ARCH__) - if (a == 0) { - printf("Error: Divide by zero encounter in floor_divide\n"); -#ifdef __HIPCC__ - abort(); -#else - asm("trap;"); -#endif - } -#else - if (a == 0) - PADDLE_THROW(platform::errors::InvalidArgument( - "Divide by zero encounter in floor_divide")); -#endif - return static_cast(std::trunc(b / a)); - } -}; - template void elementwise_floor_div(const framework::ExecutionContext &ctx, const framework::Tensor *x, diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h new file mode 100644 index 0000000000000000000000000000000000000000..abac43a2616f09b34b56aadbac7c6614f7a59fad --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_functor.h @@ -0,0 +1,117 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +// Define the binary functors used in elementwise ops. + +// Add +template +struct AddFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; } +}; +template +struct InverseAddFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b + a; } +}; + +// Subtract +template +struct SubFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a - b; } +}; +template +struct InverseSubFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b - a; } +}; + +// Multiply +template +struct MulFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; } +}; +template +struct InverseMulFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b * a; } +}; + +// Divide +#define DIV_ERROR_INFO \ + "InvalidArgumentError: Integer division by zero encountered in " \ + "(floor) divide. Please check the input value." + +template +struct DivFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; } +}; + +template +struct DivFunctor::value>::type> { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + // For int32/int64, need to check whether the divison is zero. + PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO); + return a / b; + } +}; + +template +struct InverseDivFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b / a; } +}; + +// Floor Divide +template +struct FloorDivFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO); + return static_cast(std::trunc(a / b)); + } +}; + +template +struct InverseFloorDivFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + PADDLE_ENFORCE(a != 0, DIV_ERROR_INFO); + return static_cast(std::trunc(b / a)); + } +}; + +#undef DIV_ERROR_INFO + +// Maximum +template +struct MaxFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + return a > b ? a : b; + } +}; + +// Minmum +template +struct MinFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + return a < b ? a : b; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu index d4b5d98d5b0b345119f833e5a684d8f0b6e1f310..9657e1896e3349079da6393963a30d7116b96638 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu @@ -11,21 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "paddle/fluid/operators/elementwise/elementwise_max_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -namespace ops = paddle::operators; - namespace paddle { namespace operators { -template -struct CudaMaxFunctor { - inline HOSTDEVICE T operator()(const T* args) const { - return (args[0] > args[1] ? args[0] : args[1]); - } -}; - template class ElementwiseMaxKernel : public framework::OpKernel { @@ -38,13 +30,15 @@ class ElementwiseMaxKernel int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - cuda_ctx, ins, &outs, axis, CudaMaxFunctor()); + cuda_ctx, ins, &outs, axis, MaxFunctor()); } }; } // namespace operators } // namespace paddle +namespace ops = paddle::operators; + REGISTER_OP_CUDA_KERNEL( elementwise_max, ops::ElementwiseMaxKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h index abdb1b9671de80d02b9a6a788088f47929fcc6f0..8ee8fe923a81141f782fac4d8427aff53bef893e 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h @@ -14,17 +14,13 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" namespace paddle { namespace operators { -template -struct MaxFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { return a > b ? a : b; } -}; - template class ElementwiseMaxKernel : public framework::OpKernel { public: diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu index 4a99f7e36705f0d96b200d20e880bebf5b5b2186..eed6f72b04fb9bc58e8a4a59a50102f276d6816a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu @@ -11,21 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "paddle/fluid/operators/elementwise/elementwise_min_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -namespace ops = paddle::operators; - namespace paddle { namespace operators { -template -struct CudaMinFunctor { - inline HOSTDEVICE T operator()(const T* args) const { - return (args[0] > args[1] ? args[1] : args[0]); - } -}; - template class ElementwiseMinKernel : public framework::OpKernel { @@ -38,13 +30,15 @@ class ElementwiseMinKernel int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - cuda_ctx, ins, &outs, axis, CudaMinFunctor()); + cuda_ctx, ins, &outs, axis, MinFunctor()); } }; } // namespace operators } // namespace paddle +namespace ops = paddle::operators; + REGISTER_OP_CUDA_KERNEL( elementwise_min, ops::ElementwiseMinKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.h b/paddle/fluid/operators/elementwise/elementwise_min_op.h index 5a3e7f90f3c3dbee093c17fd4c5cf863ad1f4d24..648691063c59b237852a364fa002246de7ba6137 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.h @@ -14,17 +14,13 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" namespace paddle { namespace operators { -template -struct MinFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { return a < b ? a : b; } -}; - template class ElementwiseMinKernel : public framework::OpKernel { public: diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index 33b6f1d60b8de44cb1763ea6f9473b2852c8c601..1a9ac4bd9157fa6fb50b538444666d610398ead8 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -14,7 +14,6 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" @@ -95,10 +94,10 @@ elementwise_mul_grad(const framework::ExecutionContext& ctx, const framework::Tensor* out, const framework::Tensor* dout, framework::Tensor* dx, framework::Tensor* dy) { - dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1); auto size = x->numel(); dim3 grid_size = - dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1); + dim3((size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1); SimpleElemwiseMulGradCUDAKernel< T><<().stream()>>>( diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index a734f891a9d9e83592156442e48215a93af3a920..80fa430c44307c0b1bc9fe67d211214078655f07 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/cpu_info.h" diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h deleted file mode 100644 index 8344b3d9838b007dc284ffc18d011cbb98808fbc..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h +++ /dev/null @@ -1,231 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/fluid/platform/hostdevice.h" -#ifdef __HIPCC__ -#define PADDLE_CUDA_THREAD_SIZE 256 -#else -#define PADDLE_CUDA_THREAD_SIZE 512 -#endif - -#ifdef PADDLE_WITH_CUDA -#include -#ifdef PADDLE_CUDA_FP16 -#include -#endif -#endif // PADDLE_WITH_CUDA - -#ifdef PADDLE_WITH_HIP -#include -#ifdef PADDLE_CUDA_FP16 -#include -#endif -#endif // PADDLE_WITH_HIP - -#define DIV_ERROR_INFO \ - "InvalidArgumentError: Integer division by zero encountered in divide. " \ - "Please check.\n" -namespace paddle { -namespace operators { - -#define DEFINE_SIMPLE_BINARY_FUNCTOR(Func, expr) \ - template \ - struct Func##Functor { \ - inline HOSTDEVICE T operator()(const T& a, const T& b) const { \ - return a expr b; \ - } \ - }; \ - template \ - struct Inverse##Func##Functor { \ - inline HOSTDEVICE T operator()(const T& a, const T& b) const { \ - return b expr a; \ - } \ - }; - -DEFINE_SIMPLE_BINARY_FUNCTOR(Add, +) -DEFINE_SIMPLE_BINARY_FUNCTOR(Sub, -) -DEFINE_SIMPLE_BINARY_FUNCTOR(Mul, *) -DEFINE_SIMPLE_BINARY_FUNCTOR(Div, /) -#undef DEFINE_SIMPLE_BINARY_FUNCTOR - -// special div functor for int32/int64. check divison has a zero -template -struct DivFunctor::value>::type> { - inline HOSTDEVICE T operator()(const T& a, const T& b) const { - PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO); - return a / b; - } -}; - -#define DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Func, expr) \ - template \ - struct Func##RangeFunctor { \ - Func##RangeFunctor(const T* x, const T* y, T* z) : x_(x), y_(y), z_(z) {} \ - inline HOSTDEVICE void operator()(size_t id) const { \ - z_[id] = x_[id] expr y_[id]; \ - } \ - const T* x_; \ - const T* y_; \ - T* z_; \ - }; -DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Add, +) -DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Sub, -) -DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Mul, *) -DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Div, /) -#undef DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR - -// special div functor for int32/int64. check divison has a zero -template -struct DivRangeFunctor< - T, typename std::enable_if::value>::type> { - DivRangeFunctor(const T* x, const T* y, T* z) : x_(x), y_(y), z_(z) {} - inline HOSTDEVICE void operator()(size_t id) const { - PADDLE_ENFORCE(y_[id] != 0, DIV_ERROR_INFO); - z_[id] = x_[id] / y_[id]; - } - const T* x_; - const T* y_; - T* z_; -}; - -#ifdef PADDLE_CUDA_FP16 -inline DEVICE half2 half2_add(const half2& a, const half2& b) { -#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) - return __hadd2(a, b); -#else - float a1 = __low2float(a); - float a2 = __high2float(a); - float b1 = __low2float(b); - float b2 = __high2float(b); - float r1 = a1 + b1; - float r2 = a2 + b2; - return __floats2half2_rn(r1, r2); -#endif -} - -inline DEVICE half2 half2_sub(const half2& a, const half2& b) { -#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) - return __hsub2(a, b); -#else - float a1 = __low2float(a); - float a2 = __high2float(a); - float b1 = __low2float(b); - float b2 = __high2float(b); - float r1 = a1 - b1; - float r2 = a2 - b2; - return __floats2half2_rn(r1, r2); -#endif -} - -inline DEVICE half2 half2_mul(const half2& a, const half2& b) { -#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) - return __hmul2(a, b); -#else - float a1 = __low2float(a); - float a2 = __high2float(a); - float b1 = __low2float(b); - float b2 = __high2float(b); - float r1 = a1 * b1; - float r2 = a2 * b2; - return __floats2half2_rn(r1, r2); -#endif -} - -inline DEVICE half2 half2_div(const half2& a, const half2& b) { -#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) - return __h2div(a, b); -#else - float a1 = __low2float(a); - float a2 = __high2float(a); - float b1 = __low2float(b); - float b2 = __high2float(b); - float r1 = a1 / b1; - float r2 = a2 / b2; - return __floats2half2_rn(r1, r2); -#endif -} - -#define DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Func, expr, FP16Function) \ - inline __global__ void SameDimsElemwise##Func##CUDAKernel( \ - const float* __restrict__ x, const float* __restrict__ y, float* z, \ - int64_t size) { \ - int tid = blockIdx.x * blockDim.x + threadIdx.x; \ - int stride = gridDim.x * blockDim.x; \ - int loop = size / 4; \ - int remainder = size % 4; \ - const float4* x_vec = reinterpret_cast(x); \ - const float4* y_vec = reinterpret_cast(y); \ - float4* z_vec = reinterpret_cast(z); \ - float4 x_f4, y_f4; \ - for (int i = tid; i < loop; i += stride) { \ - x_f4 = x_vec[i]; \ - y_f4 = y_vec[i]; \ - z_vec[i] = make_float4(x_f4.x expr y_f4.x, x_f4.y expr y_f4.y, \ - x_f4.z expr y_f4.z, x_f4.w expr y_f4.w); \ - } \ - if (tid == loop && remainder != 0) { \ - while (remainder) { \ - int idx = size - remainder; \ - remainder--; \ - z[idx] = x[idx] expr y[idx]; \ - } \ - } \ - } \ - inline __global__ void SameDimsElemwise##Func##CUDAKernel( \ - const half* __restrict__ x, const half* __restrict__ y, half* z, \ - int64_t size) { \ - int tid = blockIdx.x * blockDim.x + threadIdx.x; \ - int stride = gridDim.x * blockDim.x; \ - int loop = size / 8; \ - int remainder = size % 8; \ - const float4* x_vec = reinterpret_cast(x); \ - const float4* y_vec = reinterpret_cast(y); \ - float4* z_vec = reinterpret_cast(z); \ - float4 x_h8, y_h8, z_h8; \ - for (int i = tid; i < loop; i += stride) { \ - x_h8 = x_vec[i]; \ - y_h8 = y_vec[i]; \ - half2* x_h2 = reinterpret_cast(&x_h8); \ - half2* y_h2 = reinterpret_cast(&y_h8); \ - half2* z_h2 = reinterpret_cast(&z_h8); \ - z_h2[0] = FP16Function(x_h2[0], y_h2[0]); \ - z_h2[1] = FP16Function(x_h2[1], y_h2[1]); \ - z_h2[2] = FP16Function(x_h2[2], y_h2[2]); \ - z_h2[3] = FP16Function(x_h2[3], y_h2[3]); \ - z_vec[i] = z_h8; \ - } \ - if (tid == loop && remainder != 0) { \ - while (remainder) { \ - int idx = size - remainder; \ - remainder--; \ - z[idx] = __float2half(__half2float(x[idx]) expr __half2float(y[idx])); \ - } \ - } \ - } -DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Add, +, half2_add) -DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Sub, -, half2_sub) -DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Mul, *, half2_mul) -DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Div, /, half2_div) -#undef DEFINE_SIMPLE_CUDA_BINARY_KERNEL - -#endif // PADDLE_CUDA_FP16 - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 7bbfefba20fa7572d2756bba8b803d2fcc7f8682..312978a010b30ceb394d25269d9f4daa804702e0 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -25,7 +25,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/memory/malloc.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" +#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/transform.h" diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu index 2643cc0e7a27547825c56483de3f498d6cf57751..38465df243032e260abf97c0d9bc47495b627660 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu @@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" #include "paddle/fluid/platform/complex.h" @@ -59,10 +60,10 @@ elementwise_sub_grad(const framework::ExecutionContext& ctx, const framework::Tensor* out, const framework::Tensor* dout, framework::Tensor* dx, framework::Tensor* dy) { - dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1); auto size = x->numel(); dim3 grid_size = - dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1); + dim3((size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1); SimpleElemwiseSubGradCUDAKernel< T><<().stream()>>>( diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.h b/paddle/fluid/operators/elementwise/elementwise_sub_op.h index 426093413276092538c67676abb2c1e9b7f637ed..fa26722266a637b71dafa6e6b1de537438f9fa7f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.h @@ -11,10 +11,10 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #pragma once #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h index c9ba37d0008ba264b1bf9d6281b6888aa369a791..970ec69412026f384ddf853c9638b3a475a26413 100644 --- a/paddle/fluid/operators/layer_norm_op.h +++ b/paddle/fluid/operators/layer_norm_op.h @@ -19,7 +19,6 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" #if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \ diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index f266aa0cba0c96a9821dc0955cc4ead86d31056b..b0c361e86a531730a4d9999682a6ba784b6d415c 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -25,7 +25,6 @@ #include "paddle/fluid/operators/eigen/eigen_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/functors.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/for_range.h"