未验证 提交 c5eeecca 编写于 作者: Z Zeng Jinle 提交者: GitHub

Fix tensor_py.h (#17195)

* fix tensor_py,test=develop

* change class name,test=develop
上级 ee2028a1
...@@ -431,14 +431,14 @@ inline void PyCUDAPinnedTensorSetFromArray( ...@@ -431,14 +431,14 @@ inline void PyCUDAPinnedTensorSetFromArray(
namespace details { namespace details {
template <typename T> template <typename T>
constexpr bool IsValidDTypeToPyArray() { struct ValidDTypeToPyArrayChecker {
return false; static constexpr bool kValue = false;
} };
#define DECLARE_VALID_DTYPE_TO_PY_ARRAY(type) \ #define DECLARE_VALID_DTYPE_TO_PY_ARRAY(type) \
template <> \ template <> \
constexpr bool IsValidDTypeToPyArray<type>() { \ struct ValidDTypeToPyArrayChecker<type> { \
return true; \ static constexpr bool kValue = true; \
} }
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16); DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16);
...@@ -457,7 +457,8 @@ inline std::string TensorDTypeToPyDTypeStr( ...@@ -457,7 +457,8 @@ inline std::string TensorDTypeToPyDTypeStr(
if (std::is_same<T, platform::float16>::value) { \ if (std::is_same<T, platform::float16>::value) { \
return "e"; \ return "e"; \
} else { \ } else { \
PADDLE_ENFORCE(IsValidDTypeToPyArray<T>, \ constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker<T>::kValue; \
PADDLE_ENFORCE(kIsValidDType, \
"This type of tensor cannot be expose to Python"); \ "This type of tensor cannot be expose to Python"); \
return py::format_descriptor<T>::format(); \ return py::format_descriptor<T>::format(); \
} \ } \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册