diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index fd48f26f411025214fa9951ac333bf7ba0fc8731..cec21f40073e2f674f8d843c5dc9934524bdb395 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -431,14 +431,14 @@ inline void PyCUDAPinnedTensorSetFromArray( namespace details { template -constexpr bool IsValidDTypeToPyArray() { - return false; -} - -#define DECLARE_VALID_DTYPE_TO_PY_ARRAY(type) \ - template <> \ - constexpr bool IsValidDTypeToPyArray() { \ - return true; \ +struct ValidDTypeToPyArrayChecker { + static constexpr bool kValue = false; +}; + +#define DECLARE_VALID_DTYPE_TO_PY_ARRAY(type) \ + template <> \ + struct ValidDTypeToPyArrayChecker { \ + static constexpr bool kValue = true; \ } DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16); @@ -452,15 +452,16 @@ DECLARE_VALID_DTYPE_TO_PY_ARRAY(int64_t); inline std::string TensorDTypeToPyDTypeStr( framework::proto::VarType::Type type) { -#define TENSOR_DTYPE_TO_PY_DTYPE(T, proto_type) \ - if (type == proto_type) { \ - if (std::is_same::value) { \ - return "e"; \ - } else { \ - PADDLE_ENFORCE(IsValidDTypeToPyArray, \ - "This type of tensor cannot be expose to Python"); \ - return py::format_descriptor::format(); \ - } \ +#define TENSOR_DTYPE_TO_PY_DTYPE(T, proto_type) \ + if (type == proto_type) { \ + if (std::is_same::value) { \ + return "e"; \ + } else { \ + constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker::kValue; \ + PADDLE_ENFORCE(kIsValidDType, \ + "This type of tensor cannot be expose to Python"); \ + return py::format_descriptor::format(); \ + } \ } _ForEachDataType_(TENSOR_DTYPE_TO_PY_DTYPE);