From 6a1ddd61718812eb3fbae8eb51a97d7299c469a3 Mon Sep 17 00:00:00 2001 From: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com> Date: Tue, 12 Apr 2022 16:18:59 +0800 Subject: [PATCH] [cherry pick]fix paddle tensor numel check (#41665) --- paddle/fluid/platform/device/gpu/gpu_launch_config.h | 9 +++++---- paddle/phi/backends/gpu/gpu_launch_config.h | 12 ++++++------ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/platform/device/gpu/gpu_launch_config.h b/paddle/fluid/platform/device/gpu/gpu_launch_config.h index 4a550e61d42..6e1f324625d 100644 --- a/paddle/fluid/platform/device/gpu/gpu_launch_config.h +++ b/paddle/fluid/platform/device/gpu/gpu_launch_config.h @@ -99,10 +99,11 @@ struct GpuLaunchConfig { inline GpuLaunchConfig GetGpuLaunchConfig1D( const platform::CUDADeviceContext& context, int64_t numel, int vec_size = 1) { - PADDLE_ENFORCE_GT(numel, 0, platform::errors::InvalidArgument( - "element quantity should be greater than 0," - " but received value is: %d.", - numel)); + PADDLE_ENFORCE_GE(numel, 0, + platform::errors::InvalidArgument( + "element quantity should be greater than or equal to 0," + " but received value is: %d.", + numel)); // Get compute_capability const int capability = context.GetComputeCapability(); /* If thread number per block is 64/128/256/512, cuda performs better.*/ diff --git a/paddle/phi/backends/gpu/gpu_launch_config.h b/paddle/phi/backends/gpu/gpu_launch_config.h index ea54083e817..70466e2d753 100644 --- a/paddle/phi/backends/gpu/gpu_launch_config.h +++ b/paddle/phi/backends/gpu/gpu_launch_config.h @@ -101,12 +101,12 @@ struct GpuLaunchConfig { inline GpuLaunchConfig GetGpuLaunchConfig1D(const phi::GPUContext& context, int64_t numel, int vec_size = 1) { - PADDLE_ENFORCE_GT( - numel, - 0, - phi::errors::InvalidArgument("element quantity should be greater than 0," - " but received value is: %d.", - numel)); + PADDLE_ENFORCE_GE(numel, + 0, + phi::errors::InvalidArgument( + "element quantity should be greater than or equal to 0," + " but received value is: %d.", + numel)); // Get compute_capability const int capability = context.GetComputeCapability(); /* If thread number per block is 64/128/256/512, cuda performs better.*/ -- GitLab