From 51cae7f78a6ed5af750ea49f84852a064396a0f9 Mon Sep 17 00:00:00 2001 From: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com> Date: Tue, 12 Apr 2022 10:13:56 +0800 Subject: [PATCH] fix_paddle_numel_check (#41607) * fix_paddle_numel_check * fix_paddle_numel_check --- paddle/fluid/platform/device/gpu/gpu_launch_config.h | 9 +++++---- paddle/phi/backends/gpu/gpu_launch_config.h | 12 ++++++------ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/platform/device/gpu/gpu_launch_config.h b/paddle/fluid/platform/device/gpu/gpu_launch_config.h index 4a550e61d4..80d60ca95b 100644 --- a/paddle/fluid/platform/device/gpu/gpu_launch_config.h +++ b/paddle/fluid/platform/device/gpu/gpu_launch_config.h @@ -99,10 +99,11 @@ struct GpuLaunchConfig { inline GpuLaunchConfig GetGpuLaunchConfig1D( const platform::CUDADeviceContext& context, int64_t numel, int vec_size = 1) { - PADDLE_ENFORCE_GT(numel, 0, platform::errors::InvalidArgument( - "element quantity should be greater than 0," - " but received value is: %d.", - numel)); + PADDLE_ENFORCE_GE(numel, 0, + platform::errors::InvalidArgument( + "element quantity should be greater than or equal 0," + " but received value is: %d.", + numel)); // Get compute_capability const int capability = context.GetComputeCapability(); /* If thread number per block is 64/128/256/512, cuda performs better.*/ diff --git a/paddle/phi/backends/gpu/gpu_launch_config.h b/paddle/phi/backends/gpu/gpu_launch_config.h index ea54083e81..888b44632e 100644 --- a/paddle/phi/backends/gpu/gpu_launch_config.h +++ b/paddle/phi/backends/gpu/gpu_launch_config.h @@ -101,12 +101,12 @@ struct GpuLaunchConfig { inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context, int64_t numel, int vec_size = 1) { - PADDLE_ENFORCE_GT( - numel, - 0, - phi::errors::InvalidArgument("element quantity should be greater than 0," - " but received value is: %d.", - numel)); + PADDLE_ENFORCE_GE(numel, + 0, + phi::errors::InvalidArgument( + "element quantity should be greater than or equal 0," + " but received value is: %d.", + numel)); // Get compute_capability const int capability = context.GetComputeCapability(); /* If thread number per block is 64/128/256/512, cuda performs better.*/ -- GitLab