未验证 提交 5828101c 编写于 作者: D dzhwinter 提交者: GitHub

make uint8 support in data_type transform and memory optimize (#10715)

* "a piece of job."

* "fix typeo"

* "fix ci"
上级 ebefdbe3
......@@ -91,6 +91,12 @@ void TransDataType(const OpKernelType& kernel_type_for_var,
case proto::VarType::BOOL:
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
break;
case proto::VarType::INT16:
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
break;
case proto::VarType::UINT8:
framework::VisitDataType(dst_type, CastDataType<bool>(in, out, ctx));
break;
default:
PADDLE_THROW("Not support type %d", src_type);
}
......
......@@ -238,6 +238,7 @@ void BindVarDsec(pybind11::module *m) {
pybind11::enum_<pd::proto::VarType::Type>(var_desc, "VarType", "")
.value("BOOL", pd::proto::VarType::BOOL)
.value("UINT8", pd::proto::VarType::UINT8)
.value("INT16", pd::proto::VarType::INT16)
.value("INT32", pd::proto::VarType::INT32)
.value("INT64", pd::proto::VarType::INT64)
......
......@@ -72,6 +72,8 @@ def convert_np_dtype_to_dtype_(np_dtype):
return core.VarDesc.VarType.INT64
elif dtype == np.bool:
return core.VarDesc.VarType.BOOL
elif dtype == np.uint8:
return core.VarDesc.VarType.UINT8
else:
raise ValueError("Not supported numpy dtype " + str(dtype))
......
......@@ -24,7 +24,8 @@ dtype_to_size = {
core.VarDesc.VarType.INT16: 2,
core.VarDesc.VarType.INT32: 4,
core.VarDesc.VarType.INT64: 8,
core.VarDesc.VarType.BOOL: 1
core.VarDesc.VarType.BOOL: 1,
core.VarDesc.VarType.UINT8: 1,
}
SUB_BLOCK_OPS = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册