diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc index c0523f3c795b103c0c27081ec5dc717f6a0f11e0..5a57ec20585c26dbcd4251464718fc819148a7a5 100644 --- a/paddle/fluid/framework/data_type_transform.cc +++ b/paddle/fluid/framework/data_type_transform.cc @@ -91,6 +91,12 @@ void TransDataType(const OpKernelType& kernel_type_for_var, case proto::VarType::BOOL: framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); break; + case proto::VarType::INT16: + framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); + break; + case proto::VarType::UINT8: + framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); + break; default: PADDLE_THROW("Not support type %d", src_type); } diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 6471eb3ab7bf05365c0bb2bf68bb74ef9044c527..bcf6d4dd3087060c016e53722cde80704ef2e834 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -238,6 +238,7 @@ void BindVarDsec(pybind11::module *m) { pybind11::enum_(var_desc, "VarType", "") .value("BOOL", pd::proto::VarType::BOOL) + .value("UINT8", pd::proto::VarType::UINT8) .value("INT16", pd::proto::VarType::INT16) .value("INT32", pd::proto::VarType::INT32) .value("INT64", pd::proto::VarType::INT64) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 38c765938fe9d7b2103bfdd926874c485d0ff4dc..161ea55586bbb6bde2cbb0084bb67b184f91460e 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -72,6 +72,8 @@ def convert_np_dtype_to_dtype_(np_dtype): return core.VarDesc.VarType.INT64 elif dtype == np.bool: return core.VarDesc.VarType.BOOL + elif dtype == np.uint8: + return core.VarDesc.VarType.UINT8 else: raise ValueError("Not supported numpy dtype " + str(dtype)) diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index 49034b47b2d184e4027bcebc29413a163340fdaa..80a8f7c09cfe521f8f94a27e85fc8d86c02b3e97 100644 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -24,7 +24,8 @@ dtype_to_size = { core.VarDesc.VarType.INT16: 2, core.VarDesc.VarType.INT32: 4, core.VarDesc.VarType.INT64: 8, - core.VarDesc.VarType.BOOL: 1 + core.VarDesc.VarType.BOOL: 1, + core.VarDesc.VarType.UINT8: 1, } SUB_BLOCK_OPS = [