diff --git a/mindspore/ccsrc/device/gpu/cuda_common.h b/mindspore/ccsrc/device/gpu/cuda_common.h index 5a5b6416cee4285e089f82aca5124419a5dbf54b..b79ba8bc2815345b5b155a3142472e09f9a89e5e 100644 --- a/mindspore/ccsrc/device/gpu/cuda_common.h +++ b/mindspore/ccsrc/device/gpu/cuda_common.h @@ -56,7 +56,8 @@ class CudaCommon { #define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads) #define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num() #define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm() -#define MINIUM_SM 7 +#define MINIUM_SM 6 +#define RECOMMEND_SM 7 } // namespace gpu } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc b/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc index fba2b24512f24ab7c519a46a4effae1df41acbd2..e38cc02e23931e5c369ab0c533bdbca1f0d1f784 100644 --- a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc +++ b/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc @@ -96,9 +96,13 @@ std::pair GpuKernelFactory::GpuKernelAttrCheck(const std::string & bool flag = true; // data type matching check of all input parameters of kernel for (size_t input_index = 0; input_index < kernel_info->GetInputNum(); input_index++) { - if (marjor_sm < MINIUM_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) { - MS_LOG(EXCEPTION) << "Half precision op can be used on Devices which compute capacity is above " << MINIUM_SM - << ", but your device's compute capacity is " << marjor_sm; + if (marjor_sm < RECOMMEND_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) { + if (marjor_sm < MINIUM_SM) { + MS_LOG(EXCEPTION) << "Half precision ops can be used on Devices which computing capacity is >= " << MINIUM_SM + << ", but the current device's computing capacity is " << marjor_sm; + } + MS_LOG(WARNING) << "It is recommended to use devices with a computing capacity >= " << RECOMMEND_SM + << ", but the current device's computing capacity is " << marjor_sm; } if (kernel_info->GetInputDeviceType(input_index) != (iter->second)[attr_index].first.GetInputAttr(input_index).first) {