From 2367cca6ca768ae4d916a4d1fcf21e04a56cb108 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Wed, 15 Sep 2021 14:23:50 +0800 Subject: [PATCH] Unify the functor definition of elementwise add, sub, mul, div, floordiv, max, min. (#35684) --- .../elementwise/elementwise_add_op.cu | 6 +- .../elementwise/elementwise_add_op.h | 17 +- .../elementwise/elementwise_div_op.cu | 25 +- .../elementwise/elementwise_div_op.h | 1 - .../elementwise/elementwise_floordiv_op.cu | 19 +- .../elementwise/elementwise_floordiv_op.h | 43 +--- .../elementwise/elementwise_functor.h | 117 +++++++++ .../elementwise/elementwise_max_op.cu | 14 +- .../elementwise/elementwise_max_op.h | 6 +- .../elementwise/elementwise_min_op.cu | 14 +- .../elementwise/elementwise_min_op.h | 6 +- .../elementwise/elementwise_mul_op.cu | 5 +- .../elementwise/elementwise_mul_op.h | 2 +- .../elementwise/elementwise_op_function.cu.h | 231 ------------------ .../elementwise/elementwise_op_function.h | 2 +- .../elementwise/elementwise_sub_op.cu | 5 +- .../elementwise/elementwise_sub_op.h | 2 +- paddle/fluid/operators/layer_norm_op.h | 1 - paddle/fluid/operators/svd_helper.h | 1 - 19 files changed, 149 insertions(+), 368 deletions(-) create mode 100644 paddle/fluid/operators/elementwise/elementwise_functor.h delete mode 100644 paddle/fluid/operators/elementwise/elementwise_op_function.cu.h diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu index bd91142882f..331867617bd 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cu @@ -147,10 +147,10 @@ elementwise_add_grad(const framework::ExecutionContext& ctx, } else if (dx_data != dout_data && dy_data != dout_data) { auto size = x->numel(); int vec_size = max(static_cast(sizeof(float4) / sizeof(T)), 1); - dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1); dim3 grid_size = - dim3(((size + vec_size - 1) / vec_size + PADDLE_CUDA_THREAD_SIZE - 1) / - PADDLE_CUDA_THREAD_SIZE, + dim3(((size + vec_size - 1) / vec_size + ELEMENTWISE_BLOCK_SIZE - 1) / + ELEMENTWISE_BLOCK_SIZE, 1); SimpleElemwiseAddGradCUDAKernel< T><< #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -#ifdef __NVCC__ -#include -#include -#include "cub/cub.cuh" - -#endif -#ifdef __HIPCC__ -#include -#include -#include -namespace cub = hipcub; -#endif -#endif - namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cu b/paddle/fluid/operators/elementwise/elementwise_div_op.cu index 8853fd609f7..ce487f284d9 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cu @@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "paddle/fluid/operators/elementwise/elementwise_div_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/platform/complex.h" @@ -22,24 +23,6 @@ namespace plat = paddle::platform; namespace paddle { namespace operators { -template -struct CudaDivFunctor { - inline HOSTDEVICE T operator()(const T* args) const { - return args[0] / args[1]; - } -}; - -template -struct CudaDivFunctor::value>> { - inline HOSTDEVICE T operator()(const T* args) const { - PADDLE_ENFORCE(args[1] != 0, - "Invalid Argument Error: Integer division by zero " - "encountered in divide. Please check the input value."); - return args[0] / args[1]; - } -}; - template class ElementwiseDivKernel : public framework::OpKernel { @@ -52,7 +35,7 @@ class ElementwiseDivKernel int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - cuda_ctx, ins, &outs, axis, CudaDivFunctor()); + cuda_ctx, ins, &outs, axis, DivFunctor()); } }; @@ -124,10 +107,10 @@ elementwise_div_grad(const framework::ExecutionContext& ctx, const framework::Tensor* out, const framework::Tensor* dout, framework::Tensor* dx, framework::Tensor* dy) { - dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1); auto size = x->numel(); dim3 grid_size = - dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1); + dim3((size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1); SimpleElemwiseDivGradCUDAKernel< T><<().stream()>>>( diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index a0b9633acb2..0ec42e54e14 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -18,7 +18,6 @@ limitations under the License. */ #include #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" #include "paddle/fluid/operators/math/blas.h" diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu index a0510d95700..41a0ae00f27 100644 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.cu @@ -11,25 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "paddle/fluid/operators/elementwise/elementwise_floordiv_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -namespace ops = paddle::operators; -namespace plat = paddle::platform; - namespace paddle { namespace operators { -template -struct CudaFloorDivFunctor { - inline HOSTDEVICE T operator()(const T argv[]) const { - PADDLE_ENFORCE(argv[1] != 0, - "InvalidArgument: divide by zero " - "encountered in floor-divide ops, please check.\n"); - return static_cast(std::trunc(argv[0] / argv[1])); - } -}; - template class ElementwiseFloorDivKernel : public framework::OpKernel { @@ -42,13 +30,16 @@ class ElementwiseFloorDivKernel int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - cuda_ctx, ins, &outs, axis, CudaFloorDivFunctor()); + cuda_ctx, ins, &outs, axis, FloorDivFunctor()); } }; } // namespace operators } // namespace paddle +namespace ops = paddle::operators; +namespace plat = paddle::platform; + REGISTER_OP_CUDA_KERNEL( elementwise_floordiv, ops::ElementwiseFloorDivKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h index bc3c2994c84..ae8d2d8625c 100644 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h @@ -15,54 +15,13 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/math/blas.h" namespace paddle { namespace operators { -template -struct FloorDivFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { -#if defined(__HIPCC__) || defined(__CUDA_ARCH__) - if (b == 0) { - printf("Error: Divide by zero encounter in floor_divide\n"); -#ifdef __HIPCC__ - abort(); -#else - asm("trap;"); -#endif - } -#else - if (b == 0) - PADDLE_THROW(platform::errors::InvalidArgument( - "Divide by zero encounter in floor_divide")); -#endif - return static_cast(std::trunc(a / b)); - } -}; - -template -struct InverseFloorDivFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { -#if defined(__HIPCC__) || defined(__CUDA_ARCH__) - if (a == 0) { - printf("Error: Divide by zero encounter in floor_divide\n"); -#ifdef __HIPCC__ - abort(); -#else - asm("trap;"); -#endif - } -#else - if (a == 0) - PADDLE_THROW(platform::errors::InvalidArgument( - "Divide by zero encounter in floor_divide")); -#endif - return static_cast(std::trunc(b / a)); - } -}; - template void elementwise_floor_div(const framework::ExecutionContext &ctx, const framework::Tensor *x, diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h new file mode 100644 index 00000000000..abac43a2616 --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_functor.h @@ -0,0 +1,117 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +// Define the binary functors used in elementwise ops. + +// Add +template +struct AddFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a + b; } +}; +template +struct InverseAddFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b + a; } +}; + +// Subtract +template +struct SubFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a - b; } +}; +template +struct InverseSubFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b - a; } +}; + +// Multiply +template +struct MulFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a * b; } +}; +template +struct InverseMulFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b * a; } +}; + +// Divide +#define DIV_ERROR_INFO \ + "InvalidArgumentError: Integer division by zero encountered in " \ + "(floor) divide. Please check the input value." + +template +struct DivFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return a / b; } +}; + +template +struct DivFunctor::value>::type> { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + // For int32/int64, need to check whether the divison is zero. + PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO); + return a / b; + } +}; + +template +struct InverseDivFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { return b / a; } +}; + +// Floor Divide +template +struct FloorDivFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO); + return static_cast(std::trunc(a / b)); + } +}; + +template +struct InverseFloorDivFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + PADDLE_ENFORCE(a != 0, DIV_ERROR_INFO); + return static_cast(std::trunc(b / a)); + } +}; + +#undef DIV_ERROR_INFO + +// Maximum +template +struct MaxFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + return a > b ? a : b; + } +}; + +// Minmum +template +struct MinFunctor { + inline HOSTDEVICE T operator()(const T& a, const T& b) const { + return a < b ? a : b; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu index d4b5d98d5b0..9657e1896e3 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu @@ -11,21 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "paddle/fluid/operators/elementwise/elementwise_max_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -namespace ops = paddle::operators; - namespace paddle { namespace operators { -template -struct CudaMaxFunctor { - inline HOSTDEVICE T operator()(const T* args) const { - return (args[0] > args[1] ? args[0] : args[1]); - } -}; - template class ElementwiseMaxKernel : public framework::OpKernel { @@ -38,13 +30,15 @@ class ElementwiseMaxKernel int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - cuda_ctx, ins, &outs, axis, CudaMaxFunctor()); + cuda_ctx, ins, &outs, axis, MaxFunctor()); } }; } // namespace operators } // namespace paddle +namespace ops = paddle::operators; + REGISTER_OP_CUDA_KERNEL( elementwise_max, ops::ElementwiseMaxKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h index abdb1b9671d..8ee8fe923a8 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h @@ -14,17 +14,13 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" namespace paddle { namespace operators { -template -struct MaxFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { return a > b ? a : b; } -}; - template class ElementwiseMaxKernel : public framework::OpKernel { public: diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu index 4a99f7e3670..eed6f72b04f 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu @@ -11,21 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "paddle/fluid/operators/elementwise/elementwise_min_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -namespace ops = paddle::operators; - namespace paddle { namespace operators { -template -struct CudaMinFunctor { - inline HOSTDEVICE T operator()(const T* args) const { - return (args[0] > args[1] ? args[1] : args[0]); - } -}; - template class ElementwiseMinKernel : public framework::OpKernel { @@ -38,13 +30,15 @@ class ElementwiseMinKernel int axis = PackTensorsIntoVector(ctx, &ins, &outs); LaunchElementwiseCudaKernel( - cuda_ctx, ins, &outs, axis, CudaMinFunctor()); + cuda_ctx, ins, &outs, axis, MinFunctor()); } }; } // namespace operators } // namespace paddle +namespace ops = paddle::operators; + REGISTER_OP_CUDA_KERNEL( elementwise_min, ops::ElementwiseMinKernel, diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.h b/paddle/fluid/operators/elementwise/elementwise_min_op.h index 5a3e7f90f3c..648691063c5 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.h @@ -14,17 +14,13 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" namespace paddle { namespace operators { -template -struct MinFunctor { - inline HOSTDEVICE T operator()(T a, T b) const { return a < b ? a : b; } -}; - template class ElementwiseMinKernel : public framework::OpKernel { public: diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu index 33b6f1d60b8..1a9ac4bd915 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu @@ -14,7 +14,6 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/platform/complex.h" #include "paddle/fluid/platform/float16.h" @@ -95,10 +94,10 @@ elementwise_mul_grad(const framework::ExecutionContext& ctx, const framework::Tensor* out, const framework::Tensor* dout, framework::Tensor* dx, framework::Tensor* dy) { - dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1); auto size = x->numel(); dim3 grid_size = - dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1); + dim3((size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1); SimpleElemwiseMulGradCUDAKernel< T><<().stream()>>>( diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index a734f891a9d..80fa430c443 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once + #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/cpu_info.h" diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h b/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h deleted file mode 100644 index 8344b3d9838..00000000000 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.cu.h +++ /dev/null @@ -1,231 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/fluid/platform/hostdevice.h" -#ifdef __HIPCC__ -#define PADDLE_CUDA_THREAD_SIZE 256 -#else -#define PADDLE_CUDA_THREAD_SIZE 512 -#endif - -#ifdef PADDLE_WITH_CUDA -#include -#ifdef PADDLE_CUDA_FP16 -#include -#endif -#endif // PADDLE_WITH_CUDA - -#ifdef PADDLE_WITH_HIP -#include -#ifdef PADDLE_CUDA_FP16 -#include -#endif -#endif // PADDLE_WITH_HIP - -#define DIV_ERROR_INFO \ - "InvalidArgumentError: Integer division by zero encountered in divide. " \ - "Please check.\n" -namespace paddle { -namespace operators { - -#define DEFINE_SIMPLE_BINARY_FUNCTOR(Func, expr) \ - template \ - struct Func##Functor { \ - inline HOSTDEVICE T operator()(const T& a, const T& b) const { \ - return a expr b; \ - } \ - }; \ - template \ - struct Inverse##Func##Functor { \ - inline HOSTDEVICE T operator()(const T& a, const T& b) const { \ - return b expr a; \ - } \ - }; - -DEFINE_SIMPLE_BINARY_FUNCTOR(Add, +) -DEFINE_SIMPLE_BINARY_FUNCTOR(Sub, -) -DEFINE_SIMPLE_BINARY_FUNCTOR(Mul, *) -DEFINE_SIMPLE_BINARY_FUNCTOR(Div, /) -#undef DEFINE_SIMPLE_BINARY_FUNCTOR - -// special div functor for int32/int64. check divison has a zero -template -struct DivFunctor::value>::type> { - inline HOSTDEVICE T operator()(const T& a, const T& b) const { - PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO); - return a / b; - } -}; - -#define DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Func, expr) \ - template \ - struct Func##RangeFunctor { \ - Func##RangeFunctor(const T* x, const T* y, T* z) : x_(x), y_(y), z_(z) {} \ - inline HOSTDEVICE void operator()(size_t id) const { \ - z_[id] = x_[id] expr y_[id]; \ - } \ - const T* x_; \ - const T* y_; \ - T* z_; \ - }; -DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Add, +) -DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Sub, -) -DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Mul, *) -DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR(Div, /) -#undef DEFINE_SIMPLE_CUDA_BINARY_FUNCTOR - -// special div functor for int32/int64. check divison has a zero -template -struct DivRangeFunctor< - T, typename std::enable_if::value>::type> { - DivRangeFunctor(const T* x, const T* y, T* z) : x_(x), y_(y), z_(z) {} - inline HOSTDEVICE void operator()(size_t id) const { - PADDLE_ENFORCE(y_[id] != 0, DIV_ERROR_INFO); - z_[id] = x_[id] / y_[id]; - } - const T* x_; - const T* y_; - T* z_; -}; - -#ifdef PADDLE_CUDA_FP16 -inline DEVICE half2 half2_add(const half2& a, const half2& b) { -#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) - return __hadd2(a, b); -#else - float a1 = __low2float(a); - float a2 = __high2float(a); - float b1 = __low2float(b); - float b2 = __high2float(b); - float r1 = a1 + b1; - float r2 = a2 + b2; - return __floats2half2_rn(r1, r2); -#endif -} - -inline DEVICE half2 half2_sub(const half2& a, const half2& b) { -#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) - return __hsub2(a, b); -#else - float a1 = __low2float(a); - float a2 = __high2float(a); - float b1 = __low2float(b); - float b2 = __high2float(b); - float r1 = a1 - b1; - float r2 = a2 - b2; - return __floats2half2_rn(r1, r2); -#endif -} - -inline DEVICE half2 half2_mul(const half2& a, const half2& b) { -#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) - return __hmul2(a, b); -#else - float a1 = __low2float(a); - float a2 = __high2float(a); - float b1 = __low2float(b); - float b2 = __high2float(b); - float r1 = a1 * b1; - float r2 = a2 * b2; - return __floats2half2_rn(r1, r2); -#endif -} - -inline DEVICE half2 half2_div(const half2& a, const half2& b) { -#if defined(__HIPCC__) || (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530) - return __h2div(a, b); -#else - float a1 = __low2float(a); - float a2 = __high2float(a); - float b1 = __low2float(b); - float b2 = __high2float(b); - float r1 = a1 / b1; - float r2 = a2 / b2; - return __floats2half2_rn(r1, r2); -#endif -} - -#define DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Func, expr, FP16Function) \ - inline __global__ void SameDimsElemwise##Func##CUDAKernel( \ - const float* __restrict__ x, const float* __restrict__ y, float* z, \ - int64_t size) { \ - int tid = blockIdx.x * blockDim.x + threadIdx.x; \ - int stride = gridDim.x * blockDim.x; \ - int loop = size / 4; \ - int remainder = size % 4; \ - const float4* x_vec = reinterpret_cast(x); \ - const float4* y_vec = reinterpret_cast(y); \ - float4* z_vec = reinterpret_cast(z); \ - float4 x_f4, y_f4; \ - for (int i = tid; i < loop; i += stride) { \ - x_f4 = x_vec[i]; \ - y_f4 = y_vec[i]; \ - z_vec[i] = make_float4(x_f4.x expr y_f4.x, x_f4.y expr y_f4.y, \ - x_f4.z expr y_f4.z, x_f4.w expr y_f4.w); \ - } \ - if (tid == loop && remainder != 0) { \ - while (remainder) { \ - int idx = size - remainder; \ - remainder--; \ - z[idx] = x[idx] expr y[idx]; \ - } \ - } \ - } \ - inline __global__ void SameDimsElemwise##Func##CUDAKernel( \ - const half* __restrict__ x, const half* __restrict__ y, half* z, \ - int64_t size) { \ - int tid = blockIdx.x * blockDim.x + threadIdx.x; \ - int stride = gridDim.x * blockDim.x; \ - int loop = size / 8; \ - int remainder = size % 8; \ - const float4* x_vec = reinterpret_cast(x); \ - const float4* y_vec = reinterpret_cast(y); \ - float4* z_vec = reinterpret_cast(z); \ - float4 x_h8, y_h8, z_h8; \ - for (int i = tid; i < loop; i += stride) { \ - x_h8 = x_vec[i]; \ - y_h8 = y_vec[i]; \ - half2* x_h2 = reinterpret_cast(&x_h8); \ - half2* y_h2 = reinterpret_cast(&y_h8); \ - half2* z_h2 = reinterpret_cast(&z_h8); \ - z_h2[0] = FP16Function(x_h2[0], y_h2[0]); \ - z_h2[1] = FP16Function(x_h2[1], y_h2[1]); \ - z_h2[2] = FP16Function(x_h2[2], y_h2[2]); \ - z_h2[3] = FP16Function(x_h2[3], y_h2[3]); \ - z_vec[i] = z_h8; \ - } \ - if (tid == loop && remainder != 0) { \ - while (remainder) { \ - int idx = size - remainder; \ - remainder--; \ - z[idx] = __float2half(__half2float(x[idx]) expr __half2float(y[idx])); \ - } \ - } \ - } -DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Add, +, half2_add) -DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Sub, -, half2_sub) -DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Mul, *, half2_mul) -DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Div, /, half2_div) -#undef DEFINE_SIMPLE_CUDA_BINARY_KERNEL - -#endif // PADDLE_CUDA_FP16 - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h index 7bbfefba20f..312978a010b 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_function.h +++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h @@ -25,7 +25,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/memory/malloc.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" +#include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/transform.h" diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu index 2643cc0e7a2..38465df2430 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cu @@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_sub_op.h" #include "paddle/fluid/platform/complex.h" @@ -59,10 +60,10 @@ elementwise_sub_grad(const framework::ExecutionContext& ctx, const framework::Tensor* out, const framework::Tensor* dout, framework::Tensor* dx, framework::Tensor* dy) { - dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1); + dim3 block_size = dim3(ELEMENTWISE_BLOCK_SIZE, 1); auto size = x->numel(); dim3 grid_size = - dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1); + dim3((size + ELEMENTWISE_BLOCK_SIZE - 1) / ELEMENTWISE_BLOCK_SIZE, 1); SimpleElemwiseSubGradCUDAKernel< T><<().stream()>>>( diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.h b/paddle/fluid/operators/elementwise/elementwise_sub_op.h index 42609341327..fa26722266a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.h @@ -11,10 +11,10 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + #pragma once #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" diff --git a/paddle/fluid/operators/layer_norm_op.h b/paddle/fluid/operators/layer_norm_op.h index c9ba37d0008..970ec694120 100644 --- a/paddle/fluid/operators/layer_norm_op.h +++ b/paddle/fluid/operators/layer_norm_op.h @@ -19,7 +19,6 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" #if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \ diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index f266aa0cba0..b0c361e86a5 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -25,7 +25,6 @@ #include "paddle/fluid/operators/eigen/eigen_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/functors.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/for_range.h" -- GitLab