未验证 提交 c5eeecca 编写于 作者: Z Zeng Jinle 提交者: GitHub

Fix tensor_py.h (#17195)

* fix tensor_py,test=develop

* change class name,test=develop
上级 ee2028a1
......@@ -431,14 +431,14 @@ inline void PyCUDAPinnedTensorSetFromArray(
namespace details {
template <typename T>
constexpr bool IsValidDTypeToPyArray() {
return false;
}
#define DECLARE_VALID_DTYPE_TO_PY_ARRAY(type) \
template <> \
constexpr bool IsValidDTypeToPyArray<type>() { \
return true; \
struct ValidDTypeToPyArrayChecker {
static constexpr bool kValue = false;
};
#define DECLARE_VALID_DTYPE_TO_PY_ARRAY(type) \
template <> \
struct ValidDTypeToPyArrayChecker<type> { \
static constexpr bool kValue = true; \
}
DECLARE_VALID_DTYPE_TO_PY_ARRAY(platform::float16);
......@@ -452,15 +452,16 @@ DECLARE_VALID_DTYPE_TO_PY_ARRAY(int64_t);
inline std::string TensorDTypeToPyDTypeStr(
framework::proto::VarType::Type type) {
#define TENSOR_DTYPE_TO_PY_DTYPE(T, proto_type) \
if (type == proto_type) { \
if (std::is_same<T, platform::float16>::value) { \
return "e"; \
} else { \
PADDLE_ENFORCE(IsValidDTypeToPyArray<T>, \
"This type of tensor cannot be expose to Python"); \
return py::format_descriptor<T>::format(); \
} \
#define TENSOR_DTYPE_TO_PY_DTYPE(T, proto_type) \
if (type == proto_type) { \
if (std::is_same<T, platform::float16>::value) { \
return "e"; \
} else { \
constexpr auto kIsValidDType = ValidDTypeToPyArrayChecker<T>::kValue; \
PADDLE_ENFORCE(kIsValidDType, \
"This type of tensor cannot be expose to Python"); \
return py::format_descriptor<T>::format(); \
} \
}
_ForEachDataType_(TENSOR_DTYPE_TO_PY_DTYPE);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册