From 674e0ce2d6ce6354dc45bf480458ff443cba6d07 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Wed, 5 Jun 2019 11:05:56 +0800 Subject: [PATCH] Use Python C-API to speed up dygraph trace (#17837) * use python api to reduce python time cost, test=develop * fix travis ci, test=develop * fix Py_None error,test=develop --- paddle/fluid/pybind/imperative.cc | 133 +++++++++++++++++++++++--- python/paddle/fluid/dygraph/tracer.py | 19 +--- 2 files changed, 121 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 31156ab1c98..c438d6edf29 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -14,11 +14,14 @@ limitations under the License. */ #include "paddle/fluid/pybind/imperative.h" +#include #include #include #include #include #include +#include +#include #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/imperative/layer.h" @@ -31,6 +34,8 @@ limitations under the License. */ namespace paddle { namespace pybind { +namespace py = ::pybind11; + class Layer : public imperative::Layer { public: using imperative::Layer::Layer; // Inherit constructors @@ -51,10 +56,102 @@ class PYBIND11_HIDDEN PyOpBase : public imperative::OpBase { PyOpBase(const std::string &name) : OpBase(name) {} }; +// Function like obj.attr_name in Python. +static PyObject *GetPythonAttribute(PyObject *obj, const char *attr_name) { + // NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name + // is not inside obj, but it would also set the error flag of Python. + // If the error flag is set in C++, C++ code would not raise Exception, + // but Python would raise Exception once C++ call ends. + // To avoid unexpected Exception raised in Python, we check whether + // attribute exists before calling PyObject_GetAttrString. + // + // Caution: PyObject_GetAttrString would increase reference count of PyObject. + // Developer should call Py_DECREF manually after the attribute is not used. + if (PyObject_HasAttrString(obj, attr_name)) { + return PyObject_GetAttrString(obj, attr_name); + } else { + return nullptr; + } +} + +template +static T PyObjectCast(PyObject *obj) { + try { + return py::cast(py::handle(obj)); + } catch (py::cast_error &) { + PADDLE_THROW("Python object is not type of %s", typeid(T).name()); + } +} + +// NOTE(zjl): py::handle is a very light wrapper of PyObject *. +// Unlike py::object, py::handle does not change reference count of PyObject *. +static std::vector> +GetVarBaseListFromPyHandle(const py::handle &handle) { + PyObject *py_obj = handle.ptr(); // get underlying PyObject + // Python None is not nullptr in C++! + if (!py_obj || py_obj == Py_None) { + return {}; + } + + const char *kIVarField = "_ivar"; + PyObject *py_ivar = GetPythonAttribute(py_obj, kIVarField); + std::vector> result; + + if (py_ivar) { // Variable + result.emplace_back( + PyObjectCast>(py_ivar)); + Py_DECREF(py_ivar); + } else if (PyList_Check(py_obj)) { // List of Variable + size_t len = PyList_GET_SIZE(py_obj); + result.reserve(len); + for (size_t i = 0; i < len; ++i) { + PyObject *py_ivar = + PyObject_GetAttrString(PyList_GET_ITEM(py_obj, i), kIVarField); + PADDLE_ENFORCE_NOT_NULL(py_ivar); + result.emplace_back( + PyObjectCast>(py_ivar)); + Py_DECREF(py_ivar); + } + } else if (PyTuple_Check(py_obj)) { // Tuple of Variable + size_t len = PyTuple_GET_SIZE(py_obj); + result.reserve(len); + for (size_t i = 0; i < len; ++i) { + PyObject *py_ivar = + PyObject_GetAttrString(PyTuple_GET_ITEM(py_obj, i), kIVarField); + PADDLE_ENFORCE_NOT_NULL(py_ivar); + result.emplace_back( + PyObjectCast>(py_ivar)); + Py_DECREF(py_ivar); + } + } else { + PADDLE_THROW( + "unsupported type %s, must be Variable, List[Variable] or " + "tuple[Variable]", + py::str(handle)); + } + + PADDLE_ENFORCE(PyErr_Occurred() == nullptr, + py::str(py::handle(PyErr_Occurred()))); + + return result; +} + +using PyVarBaseMap = std::unordered_map; + +static imperative::VarBasePtrMap ConvertToVarBasePtrMap( + const PyVarBaseMap &map) { + imperative::VarBasePtrMap result; + for (auto &pair : map) { + auto var_vec = GetVarBaseListFromPyHandle(pair.second); + if (!var_vec.empty()) { + result.emplace(pair.first, std::move(var_vec)); + } + } + return result; +} + // Bind Methods void BindImperative(pybind11::module *m_ptr) { - namespace py = ::pybind11; - auto &m = *m_ptr; py::class_ backward_strategy( @@ -145,31 +242,41 @@ void BindImperative(pybind11::module *m_ptr) { return self.Forward(inputs); }); - py::class_(*m, "Tracer", "") + // NOTE(zjl): Tracer use PyVarBaseMap as its parameter but not VarBasePtrMap. + // We call Python C-API to convert PyVarBaseMap to VarBasePtrMap, instead + // making conversion in Python code. This speed up Tracer.trace() about 6% + // in ptb model and make time cost in Python to be nearly zero. + py::class_(m, "Tracer", "") .def("__init__", [](imperative::Tracer &self, framework::BlockDesc *root_block) { new (&self) imperative::Tracer(root_block); }) .def("trace", [](imperative::Tracer &self, imperative::OpBase *op, - const imperative::VarBasePtrMap &inputs, - imperative::VarBasePtrMap *outputs, + const PyVarBaseMap &inputs, const PyVarBaseMap &outputs, framework::AttributeMap attrs_map, const platform::CPUPlace expected_place, const bool stop_gradient = false) { - py::gil_scoped_release release; - self.Trace(op, inputs, outputs, attrs_map, expected_place, - stop_gradient); + auto ins = ConvertToVarBasePtrMap(inputs); + auto outs = ConvertToVarBasePtrMap(outputs); + { + py::gil_scoped_release release; + self.Trace(op, std::move(ins), &outs, attrs_map, expected_place, + stop_gradient); + } }) .def("trace", [](imperative::Tracer &self, imperative::OpBase *op, - const imperative::VarBasePtrMap &inputs, - imperative::VarBasePtrMap *outputs, + const PyVarBaseMap &inputs, const PyVarBaseMap &outputs, framework::AttributeMap attrs_map, const platform::CUDAPlace expected_place, const bool stop_gradient = false) { - py::gil_scoped_release release; - self.Trace(op, inputs, outputs, attrs_map, expected_place, - stop_gradient); + auto ins = ConvertToVarBasePtrMap(inputs); + auto outs = ConvertToVarBasePtrMap(outputs); + { + py::gil_scoped_release release; + self.Trace(op, std::move(ins), &outs, attrs_map, expected_place, + stop_gradient); + } }); // define parallel context diff --git a/python/paddle/fluid/dygraph/tracer.py b/python/paddle/fluid/dygraph/tracer.py index c802e31115e..aea95f2f530 100644 --- a/python/paddle/fluid/dygraph/tracer.py +++ b/python/paddle/fluid/dygraph/tracer.py @@ -52,27 +52,10 @@ class Tracer(core.Tracer): self._trace_id = 0 def trace_op(self, op, inputs, outputs, stop_gradient=False): - # TODO(hy): previous version will cause memory failed - inps = defaultdict(list) - for k, vars in six.iteritems(inputs): - if isinstance(vars, framework.Variable): - inps[k].append(vars._ivar) - elif isinstance(vars, list) or isinstance(vars, tuple): - for var in vars: - inps[k].append(var._ivar) - - outs = defaultdict(list) - for k, vars in six.iteritems(outputs): - if isinstance(vars, framework.Variable): - outs[k].append(vars._ivar) - elif isinstance(vars, list) or isinstance(vars, tuple): - for var in vars: - outs[k].append(var._ivar) - # record op's trace id op.iop._trace_id = self._trace_id - self.trace(op.iop, inps, outs, op.attrs, + self.trace(op.iop, inputs, outputs, op.attrs, framework._current_expected_place(), stop_gradient) if not stop_gradient and self._train_mode: -- GitLab