From 85f8dd1c774dc86ca9aaa2b51edf748fbe095665 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sat, 15 Sep 2018 04:29:08 +0800 Subject: [PATCH] debug version --- cmake/cuda.cmake | 2 +- cmake/flags.cmake | 16 ++++-- cmake/generic.cmake | 5 +- paddle/fluid/framework/operator.cc | 6 +++ paddle/fluid/framework/operator.h | 2 + .../inference/api/demo_ci/CMakeLists.txt | 4 +- .../inference/api/demo_ci/inference_icnet.cc | 6 +++ .../fluid/inference/api/demo_ci/vis_demo.cc | 43 +++++++++++++++- paddle/fluid/operators/conv_op.cc | 49 ++++++++++++++----- paddle/fluid/platform/cudnn_helper_test.cc | 3 ++ paddle/fluid/platform/device_context.cc | 22 +++++++-- paddle/fluid/platform/device_context_test.cu | 9 ++++ paddle/fluid/platform/dynload/cublas.h | 2 +- paddle/fluid/platform/dynload/cudnn.h | 6 ++- paddle/fluid/platform/dynload/curand.h | 2 +- paddle/fluid/platform/enforce.h | 11 ++++- 16 files changed, 158 insertions(+), 30 deletions(-) diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index c7cd5e780..ec1461524 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -172,7 +172,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF) if (NOT WIN32) # windows msvc2015 support c++11 natively. # -std=c++11 -fPIC not recoginize by msvc -list(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC") +list(APPEND CUDA_NVCC_FLAGS "-w" "-Xcompiler -fPIC" "-Xcompiler /w") endif(NOT WIN32) list(APPEND CUDA_NVCC_FLAGS "--use_fast_math") diff --git a/cmake/flags.cmake b/cmake/flags.cmake index cf0ca71d1..30757c959 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -150,7 +150,7 @@ set(COMMON_FLAGS "/w") #disable all warnings. set(GPU_COMMON_FLAGS - "") #disable all warnings + "/w") #disable all warnings endif(NOT WIN32) @@ -177,12 +177,22 @@ endif(UNIX AND NOT APPLE) foreach(flag ${COMMON_FLAGS}) safe_set_cflag(CMAKE_C_FLAGS ${flag}) safe_set_cxxflag(CMAKE_CXX_FLAGS ${flag}) + endforeach() foreach(flag ${GPU_COMMON_FLAGS}) safe_set_nvflag(${flag}) endforeach() -if(MSVC) +if(WIN32) safe_set_static_flag() -endif(MSVC) \ No newline at end of file + foreach(flag_var + CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE + CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO + CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE + CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO) + if(${flag_var} MATCHES "/W3") + string(REGEX REPLACE "/W3" "/w" ${flag_var} "${${flag_var}}") + endif(${flag_var} MATCHES "/W3") + endforeach(flag_var) +endif(WIN32) diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 7d542114f..0bb01a61b 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -243,6 +243,7 @@ function(cc_library TARGET_NAME) # add libxxx.lib prefix in windows set(${TARGET_NAME}_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}${TARGET_NAME}${CMAKE_STATIC_LIBRARY_SUFFIX}" CACHE STRING "output library name for target ${TARGET_NAME}") endif(WIN32) + message("flags" ${CMAKE_CXX_FLAGS}) if(cc_library_SRCS) if(cc_library_SHARED OR cc_library_shared) # build *.so add_library(${TARGET_NAME} SHARED ${cc_library_SRCS}) @@ -305,7 +306,7 @@ function(cc_test TARGET_NAME) set(multiValueArgs SRCS DEPS ARGS) cmake_parse_arguments(cc_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) add_executable(${TARGET_NAME} ${cc_test_SRCS}) - target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) + target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog shlwapi openblas) add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) add_test(NAME ${TARGET_NAME} COMMAND ${TARGET_NAME} ${cc_test_ARGS} @@ -375,7 +376,7 @@ function(nv_test TARGET_NAME) set(multiValueArgs SRCS DEPS) cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS}) - target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) + target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog shlwapi) add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog) add_test(${TARGET_NAME} ${TARGET_NAME}) if (nv_test_SERIAL) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 73306912c..a5168245a 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -149,8 +149,10 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) { platform::SetDeviceId(dev_id); #endif } + VLOG(3) << "start pool"; platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::RecordEvent record_event(Type(), pool.Get(place)); + VLOG(3) << "start RunImpl"; RunImpl(scope, place); VLOG(3) << place << " " << DebugStringEx(&scope); } @@ -660,12 +662,16 @@ static void CheckTensorNANOrInf(const std::string& name, void OperatorWithKernel::RunImpl(const Scope& scope, const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); + VLOG(3) << "start Infershape"; this->InferShape(&infer_shape_ctx); + VLOG(3) << "Infershape Pass"; platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); // check if op[type] has kernel registered. + VLOG(3) << "Start Kernels"; auto& all_op_kernels = AllOpKernels(); + VLOG(3) << "Kernel map finish"; auto kernels_iter = all_op_kernels.find(type_); if (kernels_iter == all_op_kernels.end()) { PADDLE_THROW( diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 1040eb882..626b50edf 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -20,6 +20,8 @@ limitations under the License. */ #include #include #include +#define GLOG_NO_ABBREVIATED_SEVERITIES +#define GOOGLE_GLOG_DLL_DECL #include "glog/logging.h" // For VLOG #include "paddle/fluid/framework/attribute.h" diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index 573f38111..d4e6bb3e4 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -114,7 +114,9 @@ if(WITH_GPU) if(NOT WIN32) set(DEPS ${DEPS} ${CUDA_LIB}/libcudart${CMAKE_SHARED_LIBRARY_SUFFIX}) else() - set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX}) + set(DEPS ${DEPS} ${CUDA_LIB}/cudart${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDA_LIB}/cublas${CMAKE_STATIC_LIBRARY_SUFFIX} ) + set(DEPS ${DEPS} ${CUDA_LIB}/cudnn${CMAKE_STATIC_LIBRARY_SUFFIX} ) endif() endif() diff --git a/paddle/fluid/inference/api/demo_ci/inference_icnet.cc b/paddle/fluid/inference/api/demo_ci/inference_icnet.cc index 4a048684b..e6040fb33 100644 --- a/paddle/fluid/inference/api/demo_ci/inference_icnet.cc +++ b/paddle/fluid/inference/api/demo_ci/inference_icnet.cc @@ -186,7 +186,12 @@ void Main(bool use_gpu) { std::cout << "begin to process data" << std::endl; // Just a single batch of data. std::string line; + std::cout << "data : " << std::endl; std::ifstream file(DATA); + if(!file.is_open()) { + std::cout << "failed open data" << DATA << std::endl; + exit(0); + } std::getline(file, line); auto record = ProcessALine(line); file.close(); @@ -207,6 +212,7 @@ void Main(bool use_gpu) { std::cout << "output: " << SummaryTensor(tensor) << std::endl; // compare with reference result + std::cout << "refer result : " << REFER << std::endl; CheckOutput(REFER, tensor); } diff --git a/paddle/fluid/inference/api/demo_ci/vis_demo.cc b/paddle/fluid/inference/api/demo_ci/vis_demo.cc index 3800d49b3..d57fb77cb 100644 --- a/paddle/fluid/inference/api/demo_ci/vis_demo.cc +++ b/paddle/fluid/inference/api/demo_ci/vis_demo.cc @@ -20,7 +20,7 @@ limitations under the License. */ #include // use glog instead of PADDLE_ENFORCE to avoid importing other paddle header files. #include #include -#include "paddle/fluid/inference/demo_ci/utils.h" +//#include "paddle/fluid/inference/demo_ci/utils.h" #include "paddle/fluid/platform/enforce.h" #ifdef PADDLE_WITH_CUDA @@ -36,6 +36,47 @@ DEFINE_bool(use_gpu, false, "Whether use gpu."); namespace paddle { namespace demo { +static void split(const std::string& str, char sep, + std::vector* pieces) { + pieces->clear(); + if (str.empty()) { + return; + } + size_t pos = 0; + size_t next = str.find(sep, pos); + while (next != std::string::npos) { + pieces->push_back(str.substr(pos, next - pos)); + pos = next + 1; + next = str.find(sep, pos); + } + if (!str.substr(pos).empty()) { + pieces->push_back(str.substr(pos)); + } +} + +/* + * Get a summary of a PaddleTensor content. + */ +static std::string SummaryTensor(const PaddleTensor& tensor) { + std::stringstream ss; + int num_elems = tensor.data.length() / PaddleDtypeSize(tensor.dtype); + + ss << "data[:10]\t"; + switch (tensor.dtype) { + case PaddleDType::INT64: { + for (int i = 0; i < std::min(num_elems, 10); i++) { + ss << static_cast(tensor.data.data())[i] << " "; + } + break; + } + case PaddleDType::FLOAT32: + for (int i = 0; i < std::min(num_elems, 10); i++) { + ss << static_cast(tensor.data.data())[i] << " "; + } + break; + } + return ss.str(); +} struct Record { std::vector data; diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 61ca80877..e08bcea48 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -11,6 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#define GLOG_NO_ABBREVIATED_SEVERITIES +#define GOOGLE_GLOG_DLL_DECL +#include #include "paddle/fluid/operators/conv_op.h" @@ -35,6 +38,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { PADDLE_ENFORCE(ctx->HasOutput("Output"), "Output(Output) of ConvOp should not be null."); + VLOG(3) << "Conv op infershape"; auto in_dims = ctx->GetInputDim("Input"); auto filter_dims = ctx->GetInputDim("Filter"); @@ -42,32 +46,51 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { std::vector paddings = ctx->Attrs().Get>("paddings"); int groups = ctx->Attrs().Get("groups"); std::vector dilations = ctx->Attrs().Get>("dilations"); - - PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, - "Conv intput should be 4-D or 5-D tensor."); - PADDLE_ENFORCE_EQ( - in_dims.size(), filter_dims.size(), - "Conv input dimension and filter dimension should be the same."); + VLOG(3) << "Conv op Before check"; + in_dims.size() == 4 || in_dims.size() == 5; + //PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5, + // "Conv intput should be 4-D or 5-D tensor."); + VLOG(3) << "check0"; + + //PADDLE_ENFORCE_EQ( + // in_dims.size(), filter_dims.size(), + // "Conv input dimension and filter dimension should be the same."); + in_dims.size() == filter_dims.size(); + VLOG(3) << "enforce check0"; PADDLE_ENFORCE( in_dims.size() - strides.size() == 2U, "Conv input dimension and strides dimension should be consistent."); + VLOG(3) << "check1"; PADDLE_ENFORCE_EQ( paddings.size(), strides.size(), "Conv paddings dimension and Conv strides dimension should be the same."); - - PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups, - "The number of input channels should be equal to filter " - "channels * groups."); - PADDLE_ENFORCE_EQ( - filter_dims[0] % groups, 0, - "The number of output channels should be divided by groups."); + + VLOG(3) << "check2"; + //in_dims[1] == filter_dims[1] * groups; + //PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups, + // "The number of input channels should be equal to filter " + // "channels * groups."); + VLOG(3) << "check3"; + //filter_dims[0] % groups == 0 ; + //PADDLE_ENFORCE_EQ( + // filter_dims[0] % groups, 0, + // "The number of output channels should be divided by groups."); + VLOG(3) << "filter" << filter_dims.size(); + VLOG(3) << "filter" << filter_dims[0]; + VLOG(3) << "check4"; + VLOG(3) << "filter" << filter_dims[1]; + VLOG(3) << "dims" << in_dims[0]; std::vector output_shape({in_dims[0], filter_dims[0]}); + VLOG(3) << "output shape"; for (size_t i = 0; i < strides.size(); ++i) { + VLOG(3) << "check5"; output_shape.push_back(ConvOutputSize(in_dims[i + 2], filter_dims[i + 2], dilations[i], paddings[i], strides[i])); + VLOG(3) << "check pass"; } + VLOG(3) << "Conv InferShape Pass"; ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->ShareLoD("Input", "Output"); } diff --git a/paddle/fluid/platform/cudnn_helper_test.cc b/paddle/fluid/platform/cudnn_helper_test.cc index 517df6863..28edfd2e5 100644 --- a/paddle/fluid/platform/cudnn_helper_test.cc +++ b/paddle/fluid/platform/cudnn_helper_test.cc @@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#define GLOG_NO_ABBREVIATED_SEVERITIES +#define GOOGLE_GLOG_DLL_DECL + #include "paddle/fluid/platform/cudnn_helper.h" #include diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 2cc26da01..476611b7d 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -40,18 +40,20 @@ DeviceContextPool::DeviceContextPool( for (auto& p : places) { set.insert(p); } - +VLOG(3) << "pool start"; for (auto& p : set) { if (platform::is_cpu_place(p)) { #ifdef PADDLE_WITH_MKLDNN device_contexts_.emplace( p, PtrType(new MKLDNNDeviceContext(boost::get(p)))); #else +VLOG(3) << "cpu context start"; device_contexts_.emplace( p, PtrType(new CPUDeviceContext(boost::get(p)))); #endif } else if (platform::is_gpu_place(p)) { #ifdef PADDLE_WITH_CUDA +VLOG(3) << "gpu context start"; device_contexts_.emplace( p, PtrType(new CUDADeviceContext(boost::get(p)))); #else @@ -61,6 +63,7 @@ DeviceContextPool::DeviceContextPool( #endif } else if (platform::is_cuda_pinned_place(p)) { #ifdef PADDLE_WITH_CUDA +VLOG(3) << "gpu pin start"; device_contexts_.emplace( p, PtrType(new CUDAPinnedDeviceContext(boost::get(p)))); @@ -70,6 +73,7 @@ DeviceContextPool::DeviceContextPool( "option"); #endif } +VLOG(3) << "pool finish"; } } @@ -147,18 +151,28 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { compute_capability = GetCUDAComputeCapability(place_.device); multi_process = GetCUDAMultiProcessors(place_.device); max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device); + VLOG(3) << "cuda info pass"; PADDLE_ENFORCE(cudaStreamCreate(&stream_)); + VLOG(3) << "cuda stream pass"; eigen_stream_.reset(new EigenCudaStreamDevice()); eigen_stream_->Reinitialize(&stream_, place); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); - PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); - PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); - if (dynload::HasCUDNN()) { + + VLOG(3) << "eigen pass"; + if (dynload::HasCUDNN()) { + VLOG(3) << "cudnn start"; PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); + VLOG(3) << "cudnn create pass"; PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_)); } else { cudnn_handle_ = nullptr; } + VLOG(3) << "cudnn pass"; + PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); + VLOG(3) << "cublas pass"; + PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); + VLOG(3) << "cublas pass"; + } CUDADeviceContext::~CUDADeviceContext() { diff --git a/paddle/fluid/platform/device_context_test.cu b/paddle/fluid/platform/device_context_test.cu index 171d2979a..3cac9aa1e 100644 --- a/paddle/fluid/platform/device_context_test.cu +++ b/paddle/fluid/platform/device_context_test.cu @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/device_context.h" +#include #include #include "glog/logging.h" @@ -23,6 +24,7 @@ TEST(Device, Init) { using paddle::platform::CUDADeviceContext; using paddle::platform::CUDAPlace; + VLOG(3) << "before Init"; int count = paddle::platform::GetCUDADeviceCount(); for (int i = 0; i < count; i++) { CUDADeviceContext* device_context = new CUDADeviceContext(CUDAPlace(i)); @@ -30,20 +32,25 @@ TEST(Device, Init) { ASSERT_NE(nullptr, gpu_device); delete device_context; } + VLOG(3) << "eigen pass"; } TEST(Device, CUDADeviceContext) { using paddle::platform::CUDADeviceContext; using paddle::platform::CUDAPlace; + VLOG(3) << "cudnn start"; int count = paddle::platform::GetCUDADeviceCount(); for (int i = 0; i < count; i++) { CUDADeviceContext* device_context = new CUDADeviceContext(CUDAPlace(i)); + VLOG(3) << "device context start"; Eigen::GpuDevice* gpu_device = device_context->eigen_device(); ASSERT_NE(nullptr, gpu_device); cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); + VLOG(3) << "cudnn pass"; ASSERT_NE(nullptr, cudnn_handle); cublasHandle_t cublas_handle = device_context->cublas_handle(); + VLOG(3) << "cublas pass"; ASSERT_NE(nullptr, cublas_handle); ASSERT_NE(nullptr, device_context->stream()); delete device_context; @@ -57,7 +64,9 @@ TEST(Device, DeviceContextPool) { using paddle::platform::CPUPlace; using paddle::platform::CUDAPlace; + VLOG(3) << "before instance"; DeviceContextPool& pool = DeviceContextPool::Instance(); + VLOG(3) << "after instance"; auto cpu_dev_ctx1 = pool.Get(CPUPlace()); auto cpu_dev_ctx2 = pool.Get(CPUPlace()); ASSERT_EQ(cpu_dev_ctx2, cpu_dev_ctx1); diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h index c7c533bd4..2f92c2cab 100644 --- a/paddle/fluid/platform/dynload/cublas.h +++ b/paddle/fluid/platform/dynload/cublas.h @@ -55,7 +55,7 @@ extern void *cublas_dso_handle; struct DynLoad__##__name { \ template \ inline cublasStatus_t operator()(Args... args) { \ - return __name(args...); \ + return ::__name(args...); \ } \ }; \ extern DynLoad__##__name __name diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index 1de587bca..fdc712ca3 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#define GLOG_NO_ABBREVIATED_SEVERITIES +#define GOOGLE_GLOG_DLL_DECL +#include #include #include // NOLINT @@ -51,7 +54,8 @@ extern void EnforceCUDNNLoaded(const char* fn_name); struct DynLoad__##__name { \ template \ inline cudnnStatus_t operator()(Args... args) { \ - return __name(args...); \ + VLOG(3) << "cudnn call"; \ + return ::__name(args...); \ } \ }; \ extern DynLoad__##__name __name diff --git a/paddle/fluid/platform/dynload/curand.h b/paddle/fluid/platform/dynload/curand.h index 2daf1b421..ef2c765c8 100644 --- a/paddle/fluid/platform/dynload/curand.h +++ b/paddle/fluid/platform/dynload/curand.h @@ -44,7 +44,7 @@ extern void *curand_dso_handle; struct DynLoad__##__name { \ template \ curandStatus_t operator()(Args... args) { \ - return __name(args...); \ + return ::__name(args...); \ } \ }; \ extern DynLoad__##__name __name diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index cc24e84d5..baa123fd0 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -295,7 +295,7 @@ inline void throw_on_error(T e) { * extra messages is also supported, for example: * PADDLE_ENFORCE(a, b, "some simple enforce failed between %d numbers", 2) */ - +#if !defined(_WIN32) #define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) \ __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, ==, !=, __VA_ARGS__) #define PADDLE_ENFORCE_NE(__VAL0, __VAL1, ...) \ @@ -309,7 +309,7 @@ inline void throw_on_error(T e) { #define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \ __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__) -#if !defined(_WIN32) + #define PADDLE_ENFORCE_NOT_NULL(__VAL, ...) \ do { \ if (UNLIKELY(nullptr == (__VAL))) { \ @@ -330,6 +330,13 @@ inline void throw_on_error(T e) { } \ } while (0) #else +#define PADDLE_ENFORCE_EQ(__VAL0, __VAL1, ...) ((__VAL0)==(__VAL1)) +#define PADDLE_ENFORCE_NE(__VAL0, __VAL1, ...) ((__VAL0)!=(__VAL1)) +#define PADDLE_ENFORCE_GT(__VAL0, __VAL1, ...) ((__VAL0)>(__VAL1)) +#define PADDLE_ENFORCE_GE(__VAL0, __VAL1, ...) ((__VAL0)>=(__VAL1)) +#define PADDLE_ENFORCE_LT(__VAL0, __VAL1, ...) ((__VAL0)<(__VAL1)) +#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) ((__VAL0)<=(__VAL1)) + #define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \ do { \ if (!((__VAL0)__CMP(__VAL1))) { \ -- GitLab