From 5828101c239fa51f85eb34caeb942e2df92467e6 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 21 May 2018 15:11:17 +0800 Subject: [PATCH] make uint8 support in data_type transform and memory optimize (#10715) * "a piece of job." * "fix typeo" * "fix ci" --- paddle/fluid/framework/data_type_transform.cc | 6 ++++++ paddle/fluid/pybind/protobuf.cc | 1 + python/paddle/fluid/framework.py | 2 ++ .../fluid/transpiler/memory_optimization_transpiler.py | 3 ++- 4 files changed, 11 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/data_type_transform.cc b/paddle/fluid/framework/data_type_transform.cc index c0523f3c795..5a57ec20585 100644 --- a/paddle/fluid/framework/data_type_transform.cc +++ b/paddle/fluid/framework/data_type_transform.cc @@ -91,6 +91,12 @@ void TransDataType(const OpKernelType& kernel_type_for_var, case proto::VarType::BOOL: framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); break; + case proto::VarType::INT16: + framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); + break; + case proto::VarType::UINT8: + framework::VisitDataType(dst_type, CastDataType(in, out, ctx)); + break; default: PADDLE_THROW("Not support type %d", src_type); } diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 6471eb3ab7b..bcf6d4dd308 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -238,6 +238,7 @@ void BindVarDsec(pybind11::module *m) { pybind11::enum_(var_desc, "VarType", "") .value("BOOL", pd::proto::VarType::BOOL) + .value("UINT8", pd::proto::VarType::UINT8) .value("INT16", pd::proto::VarType::INT16) .value("INT32", pd::proto::VarType::INT32) .value("INT64", pd::proto::VarType::INT64) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 38c765938fe..161ea55586b 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -72,6 +72,8 @@ def convert_np_dtype_to_dtype_(np_dtype): return core.VarDesc.VarType.INT64 elif dtype == np.bool: return core.VarDesc.VarType.BOOL + elif dtype == np.uint8: + return core.VarDesc.VarType.UINT8 else: raise ValueError("Not supported numpy dtype " + str(dtype)) diff --git a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py index 49034b47b2d..80a8f7c09cf 100644 --- a/python/paddle/fluid/transpiler/memory_optimization_transpiler.py +++ b/python/paddle/fluid/transpiler/memory_optimization_transpiler.py @@ -24,7 +24,8 @@ dtype_to_size = { core.VarDesc.VarType.INT16: 2, core.VarDesc.VarType.INT32: 4, core.VarDesc.VarType.INT64: 8, - core.VarDesc.VarType.BOOL: 1 + core.VarDesc.VarType.BOOL: 1, + core.VarDesc.VarType.UINT8: 1, } SUB_BLOCK_OPS = [ -- GitLab