未验证 提交 c95cd474 编写于 作者: F fengjiayi 提交者: GitHub

Merge pull request #10975 from JiayiFeng/fix_bug_in_uint8_support

Correct uint8 support
...@@ -89,4 +89,5 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>, ...@@ -89,4 +89,5 @@ REGISTER_OP_CPU_KERNEL(cast, ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, int>, ops::CastOpKernel<CPU, int>,
ops::CastOpKernel<CPU, int64_t>, ops::CastOpKernel<CPU, int64_t>,
ops::CastOpKernel<CPU, bool>, ops::CastOpKernel<CPU, bool>,
ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>); ops::CastOpKernel<CPU, paddle::platform::float16>);
...@@ -21,5 +21,5 @@ using CastOpKernel = ...@@ -21,5 +21,5 @@ using CastOpKernel =
REGISTER_OP_CUDA_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>, REGISTER_OP_CUDA_KERNEL(cast, CastOpKernel<float>, CastOpKernel<double>,
CastOpKernel<int>, CastOpKernel<int64_t>, CastOpKernel<int>, CastOpKernel<int64_t>,
CastOpKernel<bool>, CastOpKernel<bool>, CastOpKernel<uint8_t>,
CastOpKernel<paddle::platform::float16>); CastOpKernel<paddle::platform::float16>);
...@@ -117,6 +117,7 @@ PYBIND11_PLUGIN(core) { ...@@ -117,6 +117,7 @@ PYBIND11_PLUGIN(core) {
.def("set", PyCPUTensorSetFromArray<int64_t>) .def("set", PyCPUTensorSetFromArray<int64_t>)
.def("set", PyCPUTensorSetFromArray<bool>) .def("set", PyCPUTensorSetFromArray<bool>)
.def("set", PyCPUTensorSetFromArray<uint16_t>) .def("set", PyCPUTensorSetFromArray<uint16_t>)
.def("set", PyCPUTensorSetFromArray<uint8_t>)
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
.def("set", PyCUDATensorSetFromArray<float>) .def("set", PyCUDATensorSetFromArray<float>)
.def("set", PyCUDATensorSetFromArray<int>) .def("set", PyCUDATensorSetFromArray<int>)
...@@ -124,12 +125,14 @@ PYBIND11_PLUGIN(core) { ...@@ -124,12 +125,14 @@ PYBIND11_PLUGIN(core) {
.def("set", PyCUDATensorSetFromArray<int64_t>) .def("set", PyCUDATensorSetFromArray<int64_t>)
.def("set", PyCUDATensorSetFromArray<bool>) .def("set", PyCUDATensorSetFromArray<bool>)
.def("set", PyCUDATensorSetFromArray<uint16_t>) .def("set", PyCUDATensorSetFromArray<uint16_t>)
.def("set", PyCUDATensorSetFromArray<uint8_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<float>) .def("set", PyCUDAPinnedTensorSetFromArray<float>)
.def("set", PyCUDAPinnedTensorSetFromArray<int>) .def("set", PyCUDAPinnedTensorSetFromArray<int>)
.def("set", PyCUDAPinnedTensorSetFromArray<double>) .def("set", PyCUDAPinnedTensorSetFromArray<double>)
.def("set", PyCUDAPinnedTensorSetFromArray<int64_t>) .def("set", PyCUDAPinnedTensorSetFromArray<int64_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<bool>) .def("set", PyCUDAPinnedTensorSetFromArray<bool>)
.def("set", PyCUDAPinnedTensorSetFromArray<uint16_t>) .def("set", PyCUDAPinnedTensorSetFromArray<uint16_t>)
.def("set", PyCUDAPinnedTensorSetFromArray<uint8_t>)
#endif #endif
.def("shape", [](Tensor &self) { return vectorize(self.dims()); }) .def("shape", [](Tensor &self) { return vectorize(self.dims()); })
.def("set_float_element", TensorSetElement<float>) .def("set_float_element", TensorSetElement<float>)
......
...@@ -36,9 +36,11 @@ class DataToLoDTensorConverter(object): ...@@ -36,9 +36,11 @@ class DataToLoDTensorConverter(object):
self.dtype = 'float64' self.dtype = 'float64'
elif dtype == core.VarDesc.VarType.INT32: elif dtype == core.VarDesc.VarType.INT32:
self.dtype = 'int32' self.dtype = 'int32'
elif dtype == core.VarDesc.VarType.UINT8:
self.dtype = 'uint8'
else: else:
raise ValueError("dtype must be any of [int32, float32, int64, " raise ValueError("dtype must be any of [int32, float32, int64, "
"float64]") "float64, uint8]")
self.data = [] self.data = []
self.lod = [] self.lod = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册