diff --git a/paddle/fluid/operators/bce_loss_op.cu b/paddle/fluid/operators/bce_loss_op.cu
index 73f73a81c088eb4e383dc40b206f2494605ee9eb..18562b243255be9077cb2310de3d7f1d4857e969 100644
--- a/paddle/fluid/operators/bce_loss_op.cu
+++ b/paddle/fluid/operators/bce_loss_op.cu
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include <algorithm>
 #include "paddle/fluid/operators/bce_loss_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
@@ -23,6 +24,17 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 
+template <typename T>
+struct BCELossGradFunctor {
+  T one = static_cast<T>(1.0f);
+  T eps = static_cast<T>(1e-12);
+  __device__ __forceinline__ T operator()(const T& x, const T& label,
+                                          const T& dout) const {
+    T term1 = max((one - x) * x, eps);
+    return (dout * (x - label) / term1);
+  }
+};
+
 template <typename T>
 __global__ void GPUBCELossForward(const T* x_data, const T* label_data,
                                   T* out_data, const int in_numel) {
@@ -44,23 +56,6 @@ __global__ void GPUBCELossForward(const T* x_data, const T* label_data,
   }
 }
 
-template <typename T>
-__global__ void GPUBCELossBackward(const T* x_data, const T* label_data,
-                                   const T* dout_data, T* dx_data,
-                                   const int in_numel) {
-  CUDA_KERNEL_LOOP(i, in_numel) {
-    T x = x_data[i];
-    T label = label_data[i];
-    T dout = dout_data[i];
-    T one = static_cast<T>(1.);
-    T eps = static_cast<T>(1e-12);
-
-    T term1 = max((one - x) * x, eps);
-
-    dx_data[i] = dout * (x - label) / term1;
-  }
-}
-
 template <typename DeviceContext, typename T>
 class BCELossCUDAKernel : public framework::OpKernel<T> {
  public:
@@ -91,17 +86,13 @@ class BCELossGradCUDAKernel : public framework::OpKernel<T> {
     auto* labels = ctx.Input<Tensor>("Label");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-
-    int x_numel = x->numel();
-    auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-
-    auto& dev_ctx = ctx.cuda_device_context();
-    platform::GpuLaunchConfig config =
-        platform::GetGpuLaunchConfig1D(dev_ctx, x_numel);
-
-    GPUBCELossBackward<T><<<config.block_per_grid, config.thread_per_block, 0,
-                            dev_ctx.stream()>>>(
-        x->data<T>(), labels->data<T>(), dout->data<T>(), dx_data, x_numel);
+    dx->mutable_data<T>(ctx.GetPlace());
+    std::vector<const framework::Tensor*> ins = {x, labels, dout};
+    std::vector<framework::Tensor*> outs = {dx};
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto functor = BCELossGradFunctor<T>();
+    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kTernary, T, T>(
+        dev_ctx, ins, &outs, functor);
   }
 };
 
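Note: the sketch below is for illustration only and is not part of the patch. It approximates what the new code path does for the BCE loss backward pass: instead of the hand-written GPUBCELossBackward kernel, a ternary functor is handed to LaunchSameDimsElementwiseCudaKernel, which owns the launch configuration. The kernel and launcher names here (ApplyTernaryKernel, LaunchTernary, BCELossGradSketch) are hypothetical stand-ins, not Paddle APIs.

    #include <cuda_runtime.h>

    // Hypothetical stand-in for the elementwise launcher: applies a ternary
    // functor f to three same-shaped inputs, one thread per element.
    template <typename T, typename Functor>
    __global__ void ApplyTernaryKernel(const T* x, const T* label,
                                       const T* dout, T* dx, int n, Functor f) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) dx[i] = f(x[i], label[i], dout[i]);
    }

    // Same math as BCELossGradFunctor in the hunk above.
    template <typename T>
    struct BCELossGradSketch {
      T one = static_cast<T>(1.0f);
      T eps = static_cast<T>(1e-12);
      __device__ __forceinline__ T operator()(const T& x, const T& label,
                                              const T& dout) const {
        T term1 = max((one - x) * x, eps);
        return dout * (x - label) / term1;
      }
    };

    // Host-side wrapper picking a simple 1-D launch configuration.
    template <typename T>
    void LaunchTernary(const T* x, const T* label, const T* dout, T* dx, int n,
                       cudaStream_t stream) {
      const int threads = 256;
      const int blocks = (n + threads - 1) / threads;
      ApplyTernaryKernel<<<blocks, threads, 0, stream>>>(
          x, label, dout, dx, n, BCELossGradSketch<T>());
    }
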
diff --git a/paddle/fluid/operators/clip_op.h b/paddle/fluid/operators/clip_op.h
index 116edc0390e46daff29db69fe19f82d5f968bc4c..f08a7b2d573145ecc057f70aa5c8c02465746da8 100644
--- a/paddle/fluid/operators/clip_op.h
+++ b/paddle/fluid/operators/clip_op.h
@@ -18,6 +18,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/transform.h"
+#if defined(__NVCC__) || defined(__HIPCC__)
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -25,17 +28,6 @@ namespace operators {
 using framework::Tensor;
 using platform::Transform;
 
-#if defined(__NVCC__) || defined(__HIPCC__)
-template <typename T, typename UnaryOperation>
-__global__ void ClipCudaKernel(const T* input, T* out, int num,
-                               UnaryOperation op) {
-  int idx = threadIdx.x + blockDim.x * blockIdx.x;
-  if (idx < num) {
-    out[idx] = op(input[idx]);
-  }
-}
-#endif
-
 template <typename T>
 class ClipFunctor {
  public:
@@ -106,12 +98,12 @@ class ClipKernel : public framework::OpKernel<T> {
     int64_t numel = x->numel();
     if (platform::is_gpu_place(context.GetPlace())) {
 #if defined(__NVCC__) || defined(__HIPCC__)
-      int threads = 256;
-      int blocks = (numel + threads - 1) / threads;
-      ClipCudaKernel<T, ClipFunctor<T>><<<
-          blocks, threads, 0,
-          context.template device_context<platform::CUDADeviceContext>()
-              .stream()>>>(x_data, out_data, numel, ClipFunctor<T>(min, max));
+      std::vector<const framework::Tensor*> ins = {x};
+      std::vector<framework::Tensor*> outs = {out};
+      auto functor = ClipFunctor<T>(min, max);
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+          context.template device_context<platform::CUDADeviceContext>(), ins,
+          &outs, functor);
 #endif
     } else {
       Transform<DeviceContext> trans;
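Note: illustrative sketch only, not part of the patch. The reason the clip change is small is that ClipFunctor is already an ordinary callable, so the same object can serve the CPU Transform path and the GPU elementwise launch above. The snippet below shows the idea with a hypothetical ClipSketch functor (assumed __host__ __device__; Paddle's actual ClipFunctor is defined earlier in clip_op.h), applied once with std::transform on the host and once with thrust::transform on the device.

    #include <algorithm>
    #include <thrust/device_vector.h>
    #include <thrust/transform.h>
    #include <vector>

    // Hypothetical clip functor; callable from host and device code.
    template <typename T>
    struct ClipSketch {
      T min_, max_;
      __host__ __device__ ClipSketch(T lo, T hi) : min_(lo), max_(hi) {}
      __host__ __device__ T operator()(const T& v) const {
        return v < min_ ? min_ : (v > max_ ? max_ : v);
      }
    };

    int main() {
      std::vector<float> h_in = {-2.f, 0.3f, 5.f};
      std::vector<float> h_out(h_in.size());

      // CPU path: same functor through std::transform (analogous to Transform).
      std::transform(h_in.begin(), h_in.end(), h_out.begin(),
                     ClipSketch<float>(0.f, 1.f));

      // GPU path: same functor through thrust::transform (analogous to the
      // elementwise launcher used in the hunk above).
      thrust::device_vector<float> d_in = h_in;
      thrust::device_vector<float> d_out(d_in.size());
      thrust::transform(d_in.begin(), d_in.end(), d_out.begin(),
                        ClipSketch<float>(0.f, 1.f));
      return 0;
    }
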
diff --git a/paddle/fluid/operators/label_smooth_op.cu b/paddle/fluid/operators/label_smooth_op.cu
index c94a37f03f2b729d410fb1a2f04af30276482463..2e7d1de3bd756d02fa37ee14de15879c01b64385 100644
--- a/paddle/fluid/operators/label_smooth_op.cu
+++ b/paddle/fluid/operators/label_smooth_op.cu
@@ -13,19 +13,39 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/operators/label_smooth_op.h"
 namespace paddle {
 namespace operators {
 
 template <typename T>
-__global__ void LabelSmoothRunOriginKernel(const int N, const float epsilon,
-                                           const int label_dim, const T* src,
-                                           T* dst) {
-  CUDA_KERNEL_LOOP(idx, N) {
-    dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
-               static_cast<T>(epsilon / label_dim);
+struct LabelSmoothFunctor {
+  T epsilon;
+  T label_dim;
+
+  __forceinline__ LabelSmoothFunctor(float epsilon_data, int label_dim_data) {
+    epsilon = static_cast<T>(epsilon_data);
+    label_dim = static_cast<T>(label_dim_data);
   }
-}
+
+  __device__ __forceinline__ T operator()(const T& x) const {
+    return (static_cast<T>(1 - epsilon) * x +
+            static_cast<T>(epsilon / label_dim));
+  }
+};
+
+template <typename T>
+struct LabelSmoothGradFunctor {
+  T epsilon;
+
+  __forceinline__ LabelSmoothGradFunctor(float epsilon_data) {
+    epsilon = static_cast<T>(epsilon_data);
+  }
+
+  __device__ __forceinline__ T operator()(const T& x) const {
+    return static_cast<T>(1 - epsilon) * x;
+  }
+};
 
 template <typename T>
 __global__ void LabelSmoothRunDistKernel(const int N, const float epsilon,
@@ -38,14 +58,6 @@ __global__ void LabelSmoothRunDistKernel(const int N, const float epsilon,
   }
 }
 
-template <typename T>
-__global__ void LabelSmoothGradRunKernel(const int N, const float epsilon,
-                                         const T* src, T* dst) {
-  CUDA_KERNEL_LOOP(idx, N) {
-    dst[idx] = static_cast<T>(1 - epsilon) * src[idx];
-  }
-}
-
 template <typename DeviceContext, typename T>
 class LabelSmoothGPUKernel : public framework::OpKernel<T> {
  public:
@@ -69,8 +81,14 @@ class LabelSmoothGPUKernel : public framework::OpKernel<T> {
           size_prob, epsilon, dist_numel, in_data, dist_data, out_data);
 
     } else {
-      LabelSmoothRunOriginKernel<T><<<grid, threads, 0, stream>>>(
-          size_prob, epsilon, label_dim, in_data, out_data);
+      auto& dev_ctx =
+          ctx.template device_context<platform::CUDADeviceContext>();
+
+      std::vector<const framework::Tensor*> ins = {in_t};
+      std::vector<framework::Tensor*> outs = {out_t};
+      auto functor = LabelSmoothFunctor<T>(epsilon, label_dim);
+      LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+          dev_ctx, ins, &outs, functor);
     }
   }
 };
@@ -84,15 +102,13 @@ class LabelSmoothGradGPUKernel : public framework::OpKernel<T> {
     d_in_t->mutable_data<T>(ctx.GetPlace());
 
     auto epsilon = ctx.Attr<float>("epsilon");
-    auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
-    const T* in_data = d_out_t->data<T>();
-    auto size_prob = d_out_t->numel();
-    T* out_data = d_in_t->mutable_data<T>(ctx.GetPlace());
-    int threads = 512;
-    int grid = (size_prob + threads - 1) / threads;
-    auto stream = ctx.cuda_device_context().stream();
-    LabelSmoothGradRunKernel<T><<<grid, threads, 0, stream>>>(
-        size_prob, epsilon, in_data, out_data);
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+
+    std::vector<const framework::Tensor*> ins = {d_out_t};
+    std::vector<framework::Tensor*> outs = {d_in_t};
+    auto functor = LabelSmoothGradFunctor<T>(epsilon);
+    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+        dev_ctx, ins, &outs, functor);
   }
 };
}  // namespace operators
diff --git a/paddle/pten/kernels/gpu/cast_kernel.cu b/paddle/pten/kernels/gpu/cast_kernel.cu
index 58adbcc6f3599df56301eb9668467df458785eee..9f65400f93b9f0cbac0e8aae41fa4678b52a8bfa 100644
--- a/paddle/pten/kernels/gpu/cast_kernel.cu
+++ b/paddle/pten/kernels/gpu/cast_kernel.cu
@@ -19,6 +19,7 @@
 #include "paddle/pten/core/kernel_registry.h"
 
 // See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/aligned_vector.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/device/gpu/gpu_helper.h"
@@ -27,62 +28,24 @@
 
 namespace pten {
 
-template <typename InT, typename OutT, int VecSize>
-__global__ void VecCastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
-  using LoadT = paddle::platform::AlignedVector<InT, VecSize>;
-  using StoreT = paddle::platform::AlignedVector<OutT, VecSize>;
-
-  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
-  for (int64_t i = idx * VecSize; i < N;
-       i += blockDim.x * gridDim.x * VecSize) {
-    LoadT in_val;
-    paddle::platform::Load<InT, VecSize>(&in[i], &in_val);
-
-    StoreT out_val;
-#pragma unroll
-    for (int j = 0; j < VecSize; j++) {
-      out_val[j] = static_cast<OutT>(in_val[j]);
-    }
-
-    paddle::platform::Store<OutT, VecSize>(out_val, &out[i]);
-  }
-}
-
 template <typename InT, typename OutT>
-__global__ void CastCUDAKernel(const InT* in, const int64_t N, OutT* out) {
-  CUDA_KERNEL_LOOP(index, N) { out[index] = static_cast<OutT>(in[index]); }
-}
-
-template <typename InT, typename OutT>
-void CastCUDAKernelImplWithPtr(const GPUContext& dev_ctx,
-                               const InT* in_data,
-                               OutT* out_data,
-                               int64_t size) {
-  paddle::platform::GpuLaunchConfig config =
-      paddle::platform::GetGpuLaunchConfig1D(dev_ctx, size);
-  int vec_size = paddle::platform::GetVectorizedSize<OutT>(out_data);
-  if (!std::is_same<InT, OutT>::value && vec_size == 4 && size % 4 == 0) {
-    VecCastCUDAKernel<InT, OutT, 4><<<config.block_per_grid,
-                                      config.thread_per_block,
-                                      0,
-                                      dev_ctx.stream()>>>(
-        in_data, size, out_data);
-  } else {
-    CastCUDAKernel<InT, OutT><<<config.block_per_grid,
-                                config.thread_per_block,
-                                0,
-                                dev_ctx.stream()>>>(in_data, size, out_data);
+struct CastFuctor {
+  __device__ __forceinline__ OutT operator()(const InT& x) const {
+    return static_cast<OutT>(x);
   }
-}
+};
 
 template <typename InT, typename OutT>
 void CastCUDAKernelImpl(const GPUContext& dev_ctx,
                         const DenseTensor& x,
                         DenseTensor* out) {
-  auto* in_data = x.data<InT>();
-  auto size = x.numel();
-  auto* out_data = out->mutable_data<OutT>();
-  CastCUDAKernelImplWithPtr(dev_ctx, in_data, out_data, size);
+  std::vector<const DenseTensor*> inputs;
+  std::vector<DenseTensor*> outputs;
+  inputs.emplace_back(&x);
+  outputs.emplace_back(out);
+  out->mutable_data<OutT>();
+  LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, InT, OutT>(
+      dev_ctx, inputs, &outputs, CastFuctor<InT, OutT>());
 }
 
 template <typename T, typename ContextT>
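Note: a quick numeric check of the label-smoothing functors above, for illustration only and not part of the patch. With epsilon = 0.1 and label_dim = 4, LabelSmoothFunctor maps x to 0.9 * x + 0.025 and LabelSmoothGradFunctor maps dy to 0.9 * dy, which is the same arithmetic as the removed LabelSmoothRunOriginKernel and LabelSmoothGradRunKernel.

    #include <cstdio>

    int main() {
      const float epsilon = 0.1f;
      const float label_dim = 4.0f;

      float x = 1.0f;  // a one-hot label entry
      float smoothed = (1.0f - epsilon) * x + epsilon / label_dim;  // 0.925
      float dy = 2.0f;  // upstream gradient
      float dx = (1.0f - epsilon) * dy;  // 1.8
      std::printf("smoothed = %f, dx = %f\n", smoothed, dx);
      return 0;
    }
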
diff --git a/paddle/pten/kernels/gpu/scale_kernel.cu b/paddle/pten/kernels/gpu/scale_kernel.cu
index e67fd4cfdccb3c3bfd7e917cf2dbf6be166ddffd..f4bb5c5dbf75502bfad09c987d3da22864d99403 100644
--- a/paddle/pten/kernels/gpu/scale_kernel.cu
+++ b/paddle/pten/kernels/gpu/scale_kernel.cu
@@ -16,11 +16,54 @@ limitations under the License. */
 
 #include "paddle/pten/backends/gpu/gpu_context.h"
 #include "paddle/pten/core/kernel_registry.h"
-#include "paddle/pten/kernels/impl/scale_kernel_impl.h"
-
 // See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/float16.h"
 
+namespace pten {
+
+template <typename InT>
+struct ScaleFunctor {
+  InT bias;
+  InT scale;
+  bool bias_after_scale;
+
+  ScaleFunctor(InT scale_data, InT bias_data, bool is_bias_after_sacle) {
+    scale = scale_data;
+    bias = bias_data;
+    bias_after_scale = is_bias_after_sacle;
+  }
+
+  __device__ __forceinline__ InT operator()(const InT& x) const {
+    if (bias_after_scale) {
+      return scale * x + bias;
+    } else {
+      return scale * (x + bias);
+    }
+  }
+};
+
+template <typename T, typename ContextT>
+void Scale(const ContextT& dev_ctx,
+           const DenseTensor& x,
+           const Scalar& scale,
+           float bias,
+           bool bias_after_scale,
+           DenseTensor* out) {
+  std::vector<const DenseTensor*> inputs;
+  std::vector<DenseTensor*> outputs;
+  inputs.emplace_back(&x);
+  outputs.emplace_back(out);
+  out->mutable_data<T>();
+  LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+      dev_ctx,
+      inputs,
+      &outputs,
+      ScaleFunctor<T>(scale.to<T>(), static_cast<T>(bias), bias_after_scale));
+}
+
+}  // namespace pten
+
 PT_REGISTER_CTX_KERNEL(scale,
                        GPU,
                        ALL_LAYOUT,