提交 a0d36566 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix memory leak when setting tensor dtype to Quantized

GitOrigin-RevId: daa6b1912429754c79629eab69636782003797cc
上级 f01b1255
......@@ -192,12 +192,16 @@ std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(DType dty
const std::vector<std::pair<const char*, PyObject*>>& data) {
PyObject* metadata = PyDict_New();
PyObject* mgb_dtype_metadata = PyDict_New();
PyDict_SetItemString(
mgb_dtype_metadata, "name", PyUnicode_FromString(name));
PyObject* py_name = PyUnicode_FromString(name);
PyDict_SetItemString(mgb_dtype_metadata, "name", py_name);
Py_DECREF(py_name);
for (const auto& d : data) {
PyDict_SetItemString(mgb_dtype_metadata, d.first, d.second);
}
PyDict_SetItemString(metadata, "mgb_dtype", mgb_dtype_metadata);
Py_DECREF(mgb_dtype_metadata);
return metadata;
};
if (dtype.has_param()) {
......@@ -206,51 +210,62 @@ std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(DType dty
case DTypeEnum::QuantizedS1: {
auto& param = dtype.param<dtype::QuantizedS1>();
type_descr = PyArray_DescrNewFromType(NPY_INT8);
PyObject* scale = PyFloat_FromDouble(param.scale);
type_descr->metadata = build_mgb_dtype_dict(
DTypeTrait<dtype::QuantizedS1>::name,
{{"scale", PyFloat_FromDouble(param.scale)}});
DTypeTrait<dtype::QuantizedS1>::name, {{"scale", scale}});
Py_DECREF(scale);
break;
}
case DTypeEnum::Quantized4Asymm: {
auto& param = dtype.param<dtype::Quantized4Asymm>();
type_descr = PyArray_DescrNewFromType(NPY_UINT8);
PyObject* scale = PyFloat_FromDouble(param.scale);
PyObject* zero_point = PyLong_FromLong(param.zero_point);
type_descr->metadata = build_mgb_dtype_dict(
DTypeTrait<dtype::Quantized4Asymm>::name,
{{"scale", PyFloat_FromDouble(param.scale)},
{"zero_point", PyLong_FromLong(param.zero_point)}});
{{"scale", scale}, {"zero_point", zero_point}});
Py_DECREF(scale);
Py_DECREF(zero_point);
break;
}
case DTypeEnum::QuantizedS4: {
auto& param = dtype.param<dtype::QuantizedS4>();
type_descr = PyArray_DescrNewFromType(NPY_INT8);
PyObject* scale = PyFloat_FromDouble(param.scale);
type_descr->metadata = build_mgb_dtype_dict(
DTypeTrait<dtype::QuantizedS4>::name,
{{"scale", PyFloat_FromDouble(param.scale)}});
DTypeTrait<dtype::QuantizedS4>::name, {{"scale", scale}});
Py_DECREF(scale);
break;
}
case DTypeEnum::Quantized8Asymm: {
auto& param = dtype.param<dtype::Quantized8Asymm>();
type_descr = PyArray_DescrNewFromType(NPY_UINT8);
PyObject* scale = PyFloat_FromDouble(param.scale);
PyObject* zero_point = PyLong_FromLong(param.zero_point);
type_descr->metadata = build_mgb_dtype_dict(
DTypeTrait<dtype::Quantized8Asymm>::name,
{{"scale", PyFloat_FromDouble(param.scale)},
{"zero_point", PyLong_FromLong(param.zero_point)}});
{{"scale", scale}, {"zero_point", zero_point}});
Py_DECREF(scale);
Py_DECREF(zero_point);
break;
}
case DTypeEnum::QuantizedS8: {
auto& param = dtype.param<dtype::QuantizedS8>();
type_descr = PyArray_DescrNewFromType(NPY_INT8);
type_descr->metadata = build_mgb_dtype_dict(
DTypeTrait<dtype::QuantizedS8>::name,
{{"scale", PyFloat_FromDouble(param.scale)}});
PyObject* scale = PyFloat_FromDouble(param.scale);
auto metadata = build_mgb_dtype_dict(
DTypeTrait<dtype::QuantizedS8>::name, {{"scale", scale}});
type_descr->metadata = metadata;
Py_DECREF(scale);
break;
}
case DTypeEnum::QuantizedS32: {
auto& param = dtype.param<dtype::QuantizedS32>();
type_descr = PyArray_DescrNewFromType(NPY_INT32);
PyObject* scale = PyFloat_FromDouble(param.scale);
type_descr->metadata = build_mgb_dtype_dict(
DTypeTrait<dtype::QuantizedS32>::name,
{{"scale", PyFloat_FromDouble(param.scale)}});
DTypeTrait<dtype::QuantizedS32>::name, {{"scale", scale}});
Py_DECREF(scale);
break;
}
default:
......
......@@ -422,6 +422,7 @@ mgb::DType _get_dtype(py::handle tensor) {
py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
PyArray_Descr* descr;
py::object ret;
if (!PyArray_DescrConverter(dtype_hdl.ptr(), &descr)) {
throw py::value_error(ssprintf(
"can not convert to numpy.dtype from %s",
......@@ -432,11 +433,13 @@ py::object _astype_cpp(py::handle tensor, py::handle dtype_hdl) {
std::shared_ptr<OpDef> op = TypeCvt::make(npy::dtype_np2mgb_descr(descr));
py::object Op = py::cast(op);
PyObject* p[2] = {Op.ptr(), tensor.ptr()};
py::tuple ret = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
return ret[0];
py::tuple apply_res = py::reinterpret_steal<py::object>(py_apply(NULL, p, 2));
ret = apply_res[0];
} else {
return py::reinterpret_borrow<py::object>(tensor);
ret = py::reinterpret_borrow<py::object>(tensor);
}
Py_DECREF(descr);
return ret;
}
py::object _convert_single_value_cpp(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册