未验证 提交 6f0f45f6 编写于 作者: W Wilber 提交者: GitHub

copy_to_cpu support uint8 (#28372)

上级 09fd2b2a
...@@ -119,9 +119,12 @@ py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) { ...@@ -119,9 +119,12 @@ py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) {
case PaddleDType::FLOAT32: case PaddleDType::FLOAT32:
dt = py::dtype::of<float>(); dt = py::dtype::of<float>();
break; break;
case PaddleDType::UINT8:
dt = py::dtype::of<uint8_t>();
break;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type. Now only supports INT32, INT64 and " "Unsupported data type. Now only supports INT32, INT64, UINT8 and "
"FLOAT32.")); "FLOAT32."));
} }
...@@ -187,9 +190,12 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT ...@@ -187,9 +190,12 @@ py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) { // NOLINT
case PaddleDType::FLOAT32: case PaddleDType::FLOAT32:
tensor.copy_to_cpu<float>(static_cast<float *>(array.mutable_data())); tensor.copy_to_cpu<float>(static_cast<float *>(array.mutable_data()));
break; break;
case PaddleDType::UINT8:
tensor.copy_to_cpu<uint8_t>(static_cast<uint8_t *>(array.mutable_data()));
break;
default: default:
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type. Now only supports INT32, INT64 and " "Unsupported data type. Now only supports INT32, INT64, UINT8 and "
"FLOAT32.")); "FLOAT32."));
} }
return array; return array;
......
...@@ -112,4 +112,4 @@ if(LINUX AND NOT WITH_SW) ...@@ -112,4 +112,4 @@ if(LINUX AND NOT WITH_SW)
message(FATAL_ERROR "patchelf not found, please install it.\n" message(FATAL_ERROR "patchelf not found, please install it.\n"
"For Ubuntu, the command is: apt-get install -y patchelf.") "For Ubuntu, the command is: apt-get install -y patchelf.")
endif() endif()
endif(LINUX) endif()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册