/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/pybind/imperative.h"

// NOTE(review): the copy of this file under review had lost everything
// between angle brackets (include targets and template arguments). The ten
// system includes below and the template arguments in this section were
// restored from how the names are used in this file -- confirm against the
// upstream source before merging.
#include <Python.h>
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/nccl_context.h"
#include "paddle/fluid/imperative/profiler.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h"

namespace paddle {
namespace pybind {

namespace py = ::pybind11;

// Trampoline class that lets a Python subclass override
// imperative::Layer::Forward.
class Layer : public imperative::Layer {
 public:
  using imperative::Layer::Layer;  // Inherit constructors

  // Dispatches to a Python-side override named "Forward" when one exists and
  // otherwise falls back to the C++ base implementation.
  std::vector<std::shared_ptr<imperative::VarBase>> Forward(
      const std::vector<std::shared_ptr<imperative::VarBase>> &inputs)
      override {
    // The second macro argument must name the *parent* class: the macro ends
    // with `return <class>::Forward(inputs)`, and inside this trampoline the
    // unqualified name `Layer` refers to the trampoline itself, which would
    // make the fallback call this very function again without end.
    // NOTE(review): assumes imperative::Layer::Forward is not pure virtual --
    // confirm in imperative/layer.h.
    PYBIND11_OVERLOAD(std::vector<std::shared_ptr<imperative::VarBase>>,
                      imperative::Layer, Forward, inputs);  // NOLINT
  }
};

// warper for pyobject to avoid imperative module depend on python
// TODO(jiabin) Add OpBase's pybind interface back to enable backward hook
class PYBIND11_HIDDEN PyCallableObject {
 public:
  // Takes shared ownership of the wrapped Python object.
  PyCallableObject(std::shared_ptr<py::object> py_obj_ptr)
      : py_obj_ptr_(std::move(py_obj_ptr)) {}

  // Drops this wrapper's reference to the Python object.
  ~PyCallableObject() {
    // `py::call_guard<...>()` only constructs an empty policy tag and never
    // takes the GIL; a guard object has to be instantiated directly.
    // NOTE(review): the guard type was stripped from the reviewed text;
    // gil_scoped_acquire is assumed because releasing a py::object may run
    // Python code -- confirm.
    py::gil_scoped_acquire gil;
    py_obj_ptr_.reset();
  }

  // Invokes the wrapped Python callable, passing this wrapper as the only
  // argument.
  // NOTE(review): no pybind registration for PyCallableObject is visible in
  // this file, so converting `this` to a Python argument may fail at run
  // time -- confirm before relying on this call.
  void operator()() {
    py::gil_scoped_acquire gil;
    py_obj_ptr_->operator()(this);
  }

 private:
  std::shared_ptr<py::object> py_obj_ptr_;
};

// Function like obj.attr_name in Python.
static PyObject *GetPythonAttribute(PyObject *obj, const char *attr_name) { // NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name // is not inside obj, but it would also set the error flag of Python. // If the error flag is set in C++, C++ code would not raise Exception, // but Python would raise Exception once C++ call ends. // To avoid unexpected Exception raised in Python, we check whether // attribute exists before calling PyObject_GetAttrString. // // Caution: PyObject_GetAttrString would increase reference count of PyObject. // Developer should call Py_DECREF manually after the attribute is not used. if (PyObject_HasAttrString(obj, attr_name)) { return PyObject_GetAttrString(obj, attr_name); } else { return nullptr; } } template static T PyObjectCast(PyObject *obj) { try { return py::cast(py::handle(obj)); } catch (py::cast_error &) { PADDLE_THROW("Python object is not type of %s", typeid(T).name()); } } // NOTE(zjl): py::handle is a very light wrapper of PyObject *. // Unlike py::object, py::handle does not change reference count of PyObject *. static std::vector> GetVarBaseListFromPyHandle(const py::handle &handle) { PyObject *py_obj = handle.ptr(); // get underlying PyObject // Python None is not nullptr in C++! 
if (!py_obj || py_obj == Py_None) { return {}; } const char *kIVarField = "_ivar"; PyObject *py_ivar = GetPythonAttribute(py_obj, kIVarField); std::vector> result; if (py_ivar) { // Variable result.emplace_back( PyObjectCast>(py_ivar)); Py_DECREF(py_ivar); } else if (PyList_Check(py_obj)) { // List of Variable size_t len = PyList_GET_SIZE(py_obj); result.reserve(len); for (size_t i = 0; i < len; ++i) { PyObject *py_ivar = PyObject_GetAttrString(PyList_GET_ITEM(py_obj, i), kIVarField); PADDLE_ENFORCE_NOT_NULL(py_ivar); result.emplace_back( PyObjectCast>(py_ivar)); Py_DECREF(py_ivar); } } else if (PyTuple_Check(py_obj)) { // Tuple of Variable size_t len = PyTuple_GET_SIZE(py_obj); result.reserve(len); for (size_t i = 0; i < len; ++i) { PyObject *py_ivar = PyObject_GetAttrString(PyTuple_GET_ITEM(py_obj, i), kIVarField); PADDLE_ENFORCE_NOT_NULL(py_ivar); result.emplace_back( PyObjectCast>(py_ivar)); Py_DECREF(py_ivar); } } else { PADDLE_THROW( "unsupported type %s, must be Variable, list[Variable] or " "tuple[Variable]", py::str(handle)); } return result; } using PyNameVarBaseMap = std::unordered_map; static imperative::NameVarBaseMap ConvertToNameVarBaseMap( const PyNameVarBaseMap &map) { imperative::NameVarBaseMap result; for (auto &pair : map) { auto var_vec = GetVarBaseListFromPyHandle(pair.second); if (!var_vec.empty()) { result.emplace(pair.first, std::move(var_vec)); } } PADDLE_ENFORCE_EQ(PyErr_Occurred() == nullptr, true, py::str(py::handle(PyErr_Occurred()))); return result; } static std::string GetTypeName(const imperative::VarBase &var) { if (var.Type() == framework::proto::VarType::RAW) { return "RAW"; } else if (!var.Var().IsInitialized()) { return "nullptr"; } else { return framework::ToTypeName(var.Var().Type()); } } // Bind Methods void BindImperative(py::module *m_ptr) { auto &m = *m_ptr; py::class_ backward_strategy( m, "BackwardStrategy", R"DOC( BackwardStrategy is a descriptor of a how to run the backward process. Now it has: 1. 
:code:`sort_sum_gradient`, which will sum the gradient by the reverse order of trace. Examples: .. code-block:: python import numpy as np import paddle.fluid as fluid from paddle.fluid import FC x = np.ones([2, 2], np.float32) with fluid.dygraph.guard(): inputs2 = [] for _ in range(10): inputs2.append(fluid.dygraph.base.to_variable(x)) ret2 = fluid.layers.sums(inputs2) loss2 = fluid.layers.reduce_sum(ret2) backward_strategy = fluid.dygraph.BackwardStrategy() backward_strategy.sort_sum_gradient = True loss2.backward(backward_strategy) )DOC"); backward_strategy.def(py::init()) .def_property("sort_sum_gradient", [](const imperative::detail::BackwardStrategy &self) { return self.sorted_sum_gradient_; }, [](imperative::detail::BackwardStrategy &self, bool sorted_sum_gradient) { self.sorted_sum_gradient_ = sorted_sum_gradient; }); m.def("start_imperative_gperf_profiler", []() { imperative::StartProfile(); }); m.def("stop_imperative_gperf_profiler", []() { imperative::StopProfile(); }); m.def("_is_dygraph_debug_enabled", []() { return imperative::IsDebugEnabled(); }); m.def("_dygraph_debug_level", []() { return imperative::GetDebugLevel(); }); py::class_>( m, "VarBase", R"DOC()DOC") .def_static("_alive_vars", &imperative::VarBase::AliveVarNames) .def("__init__", [](imperative::VarBase &self, const std::string &name, framework::proto::VarType::Type type, framework::proto::VarType::Type dtype, const std::vector &dims, bool stop_gradient, bool persistable) { new (&self) imperative::VarBase(name); self.SetPersistable(persistable); self.SetType(type); self.SetDataType(dtype); self.SetStopGradient(stop_gradient); if (type == framework::proto::VarType::LOD_TENSOR) { auto *tensor = self.MutableVar()->GetMutable(); tensor->Resize(framework::make_ddim(dims)); } }) .def("_run_backward", [](imperative::VarBase &self, const imperative::detail::BackwardStrategy &bckst, const imperative::Tracer &tracer) { // TODO(jiabin): when we impl more backward execution we can select // them 
imperative::Engine *engine = tracer.GetDefaultEngine(); VLOG(3) << "Start backward"; engine->Init(&self, bckst); engine->Execute(); VLOG(3) << "Finish backward"; }, py::call_guard()) .def("_grad_name", &imperative::VarBase::GradVarName) .def("_grad_value", [](imperative::VarBase &self) { return self.MutableGradVar()->Get(); }, py::return_value_policy::reference) .def("_clear_gradient", &imperative::VarBase::ClearGradient) .def("_grad_ivar", [](const imperative::VarBase &self) { auto &grad_var = self.GradVarBase(); if (grad_var && grad_var->Var().IsInitialized()) { return grad_var; } else { return std::shared_ptr(nullptr); } }, py::return_value_policy::copy) .def("_copy_to", [](const imperative::VarBase &self, const platform::CPUPlace &place, bool blocking) { return self.NewVarBase(place, blocking); }, py::return_value_policy::copy) .def("_copy_to", [](const imperative::VarBase &self, const platform::CUDAPlace &place, bool blocking) { return self.NewVarBase(place, blocking); }, py::return_value_policy::copy) .def("value", [](imperative::VarBase &self) { return self.MutableVar(); }, py::return_value_policy::reference) .def_property("name", &imperative::VarBase::Name, &imperative::VarBase::SetName) .def_property_readonly( "shape", [](imperative::VarBase &self) { if (self.Var().IsType()) { return framework::vectorize( self.Var().Get().dims()); } else { VLOG(2) << "It is meaningless to get shape of variable type " << GetTypeName(self); return std::vector(); } }) .def_property_readonly("type", &imperative::VarBase::Type) .def_property_readonly("dtype", &imperative::VarBase::DataType) .def_property("persistable", &imperative::VarBase::Persistable, &imperative::VarBase::SetPersistable) .def_property("stop_gradient", &imperative::VarBase::StopGradient, &imperative::VarBase::SetStopGradient); py::class_ layer(m, "Layer"); layer.def(py::init<>()) .def("forward", [](imperative::Layer &self, const std::vector> &inputs) { return self.Forward(inputs); }); py::class_(m, "Tracer", 
"") .def("__init__", [](imperative::Tracer &self) { new (&self) imperative::Tracer(); }) .def("trace", [](imperative::Tracer &self, const std::string &type, const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, framework::AttributeMap attrs, const platform::CUDAPlace &place, bool trace_backward) { auto ins_map = ConvertToNameVarBaseMap(ins); auto outs_map = ConvertToNameVarBaseMap(outs); { py::gil_scoped_release release; self.TraceOp(type, std::move(ins_map), std::move(outs_map), std::move(attrs), place, trace_backward); } }) .def("trace", [](imperative::Tracer &self, const std::string &type, const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs, framework::AttributeMap attrs, const platform::CPUPlace &place, bool trace_backward) { auto ins_map = ConvertToNameVarBaseMap(ins); auto outs_map = ConvertToNameVarBaseMap(outs); { py::gil_scoped_release release; self.TraceOp(type, std::move(ins_map), std::move(outs_map), std::move(attrs), place, trace_backward); } }); // define parallel context py::class_ parallel_strategy( m, "ParallelStrategy", ""); parallel_strategy.def(py::init()) .def_property( "nranks", [](const imperative::ParallelStrategy &self) { return self.nranks_; }, [](imperative::ParallelStrategy &self, int nranks) { self.nranks_ = nranks; }) .def_property("local_rank", [](const imperative::ParallelStrategy &self) { return self.local_rank_; }, [](imperative::ParallelStrategy &self, int local_rank) { self.local_rank_ = local_rank; }) .def_property( "trainer_endpoints", [](const imperative::ParallelStrategy &self) { return self.trainer_endpoints_; }, [](imperative::ParallelStrategy &self, std::vector eps) { self.trainer_endpoints_ = eps; }) .def_property("current_endpoint", [](const imperative::ParallelStrategy &self) { return self.current_endpoint_; }, [](imperative::ParallelStrategy &self, const std::string &ep) { self.current_endpoint_ = ep; }); #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32) py::class_ nccl_ctx(m, "NCCLParallelContext"); 
nccl_ctx .def(py::init()) .def("init", [](imperative::NCCLParallelContext &self) { self.Init(); }); #endif } } // namespace pybind } // namespace paddle