未验证 提交 674e0ce2 编写于 作者: Z Zeng Jinle 提交者: GitHub

Use Python C-API to speed up dygraph trace (#17837)

* use python api to reduce python time cost, test=develop

* fix travis ci, test=develop

* fix Py_None error,test=develop
上级 47cc1b51
......@@ -14,11 +14,14 @@ limitations under the License. */
#include "paddle/fluid/pybind/imperative.h"
#include <Python.h>
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include <memory>
#include <unordered_map>
#include <utility>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/imperative/layer.h"
......@@ -31,6 +34,8 @@ limitations under the License. */
namespace paddle {
namespace pybind {
namespace py = ::pybind11;
class Layer : public imperative::Layer {
public:
using imperative::Layer::Layer; // Inherit constructors
......@@ -51,10 +56,102 @@ class PYBIND11_HIDDEN PyOpBase : public imperative::OpBase {
PyOpBase(const std::string &name) : OpBase(name) {}
};
// Function like obj.attr_name in Python.
static PyObject *GetPythonAttribute(PyObject *obj, const char *attr_name) {
// NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name
// is not inside obj, but it would also set the error flag of Python.
// If the error flag is set in C++, C++ code would not raise Exception,
// but Python would raise Exception once C++ call ends.
// To avoid unexpected Exception raised in Python, we check whether
// attribute exists before calling PyObject_GetAttrString.
//
// Caution: PyObject_GetAttrString would increase reference count of PyObject.
// Developer should call Py_DECREF manually after the attribute is not used.
if (PyObject_HasAttrString(obj, attr_name)) {
return PyObject_GetAttrString(obj, attr_name);
} else {
return nullptr;
}
}
template <typename T>
static T PyObjectCast(PyObject *obj) {
try {
return py::cast<T>(py::handle(obj));
} catch (py::cast_error &) {
PADDLE_THROW("Python object is not type of %s", typeid(T).name());
}
}
// NOTE(zjl): py::handle is a very light wrapper of PyObject *.
// Unlike py::object, py::handle does not change reference count of PyObject *.
static std::vector<std::shared_ptr<imperative::VarBase>>
GetVarBaseListFromPyHandle(const py::handle &handle) {
PyObject *py_obj = handle.ptr(); // get underlying PyObject
// Python None is not nullptr in C++!
if (!py_obj || py_obj == Py_None) {
return {};
}
const char *kIVarField = "_ivar";
PyObject *py_ivar = GetPythonAttribute(py_obj, kIVarField);
std::vector<std::shared_ptr<imperative::VarBase>> result;
if (py_ivar) { // Variable
result.emplace_back(
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
Py_DECREF(py_ivar);
} else if (PyList_Check(py_obj)) { // List of Variable
size_t len = PyList_GET_SIZE(py_obj);
result.reserve(len);
for (size_t i = 0; i < len; ++i) {
PyObject *py_ivar =
PyObject_GetAttrString(PyList_GET_ITEM(py_obj, i), kIVarField);
PADDLE_ENFORCE_NOT_NULL(py_ivar);
result.emplace_back(
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
Py_DECREF(py_ivar);
}
} else if (PyTuple_Check(py_obj)) { // Tuple of Variable
size_t len = PyTuple_GET_SIZE(py_obj);
result.reserve(len);
for (size_t i = 0; i < len; ++i) {
PyObject *py_ivar =
PyObject_GetAttrString(PyTuple_GET_ITEM(py_obj, i), kIVarField);
PADDLE_ENFORCE_NOT_NULL(py_ivar);
result.emplace_back(
PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
Py_DECREF(py_ivar);
}
} else {
PADDLE_THROW(
"unsupported type %s, must be Variable, List[Variable] or "
"tuple[Variable]",
py::str(handle));
}
PADDLE_ENFORCE(PyErr_Occurred() == nullptr,
py::str(py::handle(PyErr_Occurred())));
return result;
}
using PyVarBaseMap = std::unordered_map<std::string, py::handle>;
static imperative::VarBasePtrMap ConvertToVarBasePtrMap(
const PyVarBaseMap &map) {
imperative::VarBasePtrMap result;
for (auto &pair : map) {
auto var_vec = GetVarBaseListFromPyHandle(pair.second);
if (!var_vec.empty()) {
result.emplace(pair.first, std::move(var_vec));
}
}
return result;
}
// Bind Methods
void BindImperative(pybind11::module *m_ptr) {
namespace py = ::pybind11;
auto &m = *m_ptr;
py::class_<imperative::detail::BackwardStrategy> backward_strategy(
......@@ -145,31 +242,41 @@ void BindImperative(pybind11::module *m_ptr) {
return self.Forward(inputs);
});
py::class_<imperative::Tracer>(*m, "Tracer", "")
// NOTE(zjl): Tracer use PyVarBaseMap as its parameter but not VarBasePtrMap.
// We call Python C-API to convert PyVarBaseMap to VarBasePtrMap, instead
// making conversion in Python code. This speed up Tracer.trace() about 6%
// in ptb model and make time cost in Python to be nearly zero.
py::class_<imperative::Tracer>(m, "Tracer", "")
.def("__init__",
[](imperative::Tracer &self, framework::BlockDesc *root_block) {
new (&self) imperative::Tracer(root_block);
})
.def("trace",
[](imperative::Tracer &self, imperative::OpBase *op,
const imperative::VarBasePtrMap &inputs,
imperative::VarBasePtrMap *outputs,
const PyVarBaseMap &inputs, const PyVarBaseMap &outputs,
framework::AttributeMap attrs_map,
const platform::CPUPlace expected_place,
const bool stop_gradient = false) {
auto ins = ConvertToVarBasePtrMap(inputs);
auto outs = ConvertToVarBasePtrMap(outputs);
{
py::gil_scoped_release release;
self.Trace(op, inputs, outputs, attrs_map, expected_place,
self.Trace(op, std::move(ins), &outs, attrs_map, expected_place,
stop_gradient);
}
})
.def("trace", [](imperative::Tracer &self, imperative::OpBase *op,
const imperative::VarBasePtrMap &inputs,
imperative::VarBasePtrMap *outputs,
const PyVarBaseMap &inputs, const PyVarBaseMap &outputs,
framework::AttributeMap attrs_map,
const platform::CUDAPlace expected_place,
const bool stop_gradient = false) {
auto ins = ConvertToVarBasePtrMap(inputs);
auto outs = ConvertToVarBasePtrMap(outputs);
{
py::gil_scoped_release release;
self.Trace(op, inputs, outputs, attrs_map, expected_place,
self.Trace(op, std::move(ins), &outs, attrs_map, expected_place,
stop_gradient);
}
});
// define parallel context
......
......@@ -52,27 +52,10 @@ class Tracer(core.Tracer):
self._trace_id = 0
def trace_op(self, op, inputs, outputs, stop_gradient=False):
# TODO(hy): previous version will cause memory failed
inps = defaultdict(list)
for k, vars in six.iteritems(inputs):
if isinstance(vars, framework.Variable):
inps[k].append(vars._ivar)
elif isinstance(vars, list) or isinstance(vars, tuple):
for var in vars:
inps[k].append(var._ivar)
outs = defaultdict(list)
for k, vars in six.iteritems(outputs):
if isinstance(vars, framework.Variable):
outs[k].append(vars._ivar)
elif isinstance(vars, list) or isinstance(vars, tuple):
for var in vars:
outs[k].append(var._ivar)
# record op's trace id
op.iop._trace_id = self._trace_id
self.trace(op.iop, inps, outs, op.attrs,
self.trace(op.iop, inputs, outputs, op.attrs,
framework._current_expected_place(), stop_gradient)
if not stop_gradient and self._train_mode:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册