imperative.cc 13.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/pybind/imperative.h"
16

17
#include <Python.h>
18 19 20 21
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
22
#include <memory>
23 24
#include <unordered_map>
#include <utility>
25

26
#include "paddle/fluid/framework/block_desc.h"
27 28
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/profiler.h"
29
#include "paddle/fluid/imperative/tracer.h"
M
minqiyang 已提交
30
#include "paddle/fluid/imperative/type_defs.h"
31

32 33
#include "paddle/fluid/pybind/pybind_boost_headers.h"

34 35 36
namespace paddle {
namespace pybind {

37 38
namespace py = ::pybind11;

39 40 41 42
class Layer : public imperative::Layer {
 public:
  using imperative::Layer::Layer;  // Inherit constructors

43 44 45 46 47
  std::vector<std::shared_ptr<imperative::VarBase>> Forward(
      const std::vector<std::shared_ptr<imperative::VarBase>> &inputs)
      override {
    PYBIND11_OVERLOAD(std::vector<std::shared_ptr<imperative::VarBase>>, Layer,
                      Forward,
48 49 50 51 52 53 54 55 56 57 58
                      inputs);  // NOLINT
  }
};

// Alias type registered as the second template argument of
// py::class_<imperative::OpBase, PyOpBase> in BindImperative.
class PYBIND11_HIDDEN PyOpBase : public imperative::OpBase {
 public:
  using imperative::OpBase::OpBase;  // Inherit constructors

  // Explicit so a std::string never silently converts into an operator.
  // pybind11 constructs alias types with direct initialization, which is
  // unaffected by `explicit`.
  explicit PyOpBase(const std::string &name) : OpBase(name) {}
};

59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
// Function like obj.attr_name in Python. Returns a new reference to the
// attribute, or nullptr when the lookup fails (e.g. obj has no attr_name).
static PyObject *GetPythonAttribute(PyObject *obj, const char *attr_name) {
  // NOTE(zjl): PyObject_GetAttrString would return nullptr when attr_name
  // is not inside obj, but it would also set the error flag of Python.
  // If the error flag is set in C++, C++ code would not raise Exception,
  // but Python would raise Exception once C++ call ends.
  // To avoid unexpected Exception raised in Python, the error flag is
  // cleared whenever the lookup fails.
  //
  // The attribute is resolved exactly once. PyObject_HasAttrString is itself
  // implemented as PyObject_GetAttrString + PyErr_Clear, so checking it first
  // and then fetching the attribute would resolve it twice on the
  // Tracer.trace() conversion path.
  //
  // Caution: the returned PyObject is a new reference.
  // Developer should call Py_DECREF manually after the attribute is not used.
  PyObject *attr = PyObject_GetAttrString(obj, attr_name);
  if (attr == nullptr) {
    PyErr_Clear();
  }
  return attr;
}

// Converts a borrowed PyObject into the C++ type T through pybind11.
// A pybind11 cast failure is reported as a Paddle error naming T.
template <typename T>
static T PyObjectCast(PyObject *obj) {
  py::handle borrowed(obj);  // non-owning view; refcount is not touched
  try {
    return borrowed.cast<T>();
  } catch (const py::cast_error &) {
    PADDLE_THROW("Python object is not type of %s", typeid(T).name());
  }
}

// NOTE(zjl): py::handle is a very light wrapper of PyObject *.
// Unlike py::object, py::handle does not change reference count of PyObject *.
//
// Converts a Python value into a list of VarBase pointers:
//   - nullptr or None                     -> empty list
//   - an object with an `_ivar` attribute -> that single VarBase
//   - a list or tuple of such objects     -> one VarBase per item, in order
// Any other value fails with PADDLE_THROW; a list/tuple item without `_ivar`
// fails PADDLE_ENFORCE_NOT_NULL; an `_ivar` that is not a VarBase fails in
// PyObjectCast.
static std::vector<std::shared_ptr<imperative::VarBase>>
GetVarBaseListFromPyHandle(const py::handle &handle) {
  PyObject *py_obj = handle.ptr();  // get underlying PyObject
  // Python None is not nullptr in C++!
  if (!py_obj || py_obj == Py_None) {
    return {};
  }

  const char *kIVarField = "_ivar";

  // Takes ownership of a new reference to an `_ivar` object and casts it.
  // The py::object releases the reference on every exit path, including when
  // PyObjectCast throws (a manual Py_DECREF after the cast leaked it there).
  auto steal_and_cast =
      [](PyObject *new_ref) -> std::shared_ptr<imperative::VarBase> {
    py::object ivar = py::reinterpret_steal<py::object>(new_ref);
    return PyObjectCast<std::shared_ptr<imperative::VarBase>>(ivar.ptr());
  };

  // Reads `_ivar` from one list/tuple element; the element must have it.
  auto item_to_var =
      [&](PyObject *item) -> std::shared_ptr<imperative::VarBase> {
    PyObject *item_ivar = PyObject_GetAttrString(item, kIVarField);
    PADDLE_ENFORCE_NOT_NULL(item_ivar);
    return steal_and_cast(item_ivar);
  };

  std::vector<std::shared_ptr<imperative::VarBase>> result;

  PyObject *py_ivar = GetPythonAttribute(py_obj, kIVarField);
  if (py_ivar) {  // Variable
    result.emplace_back(steal_and_cast(py_ivar));
  } else if (PyList_Check(py_obj)) {  // List of Variable
    const Py_ssize_t len = PyList_GET_SIZE(py_obj);
    result.reserve(static_cast<size_t>(len));
    for (Py_ssize_t i = 0; i < len; ++i) {
      result.emplace_back(item_to_var(PyList_GET_ITEM(py_obj, i)));
    }
  } else if (PyTuple_Check(py_obj)) {  // Tuple of Variable
    const Py_ssize_t len = PyTuple_GET_SIZE(py_obj);
    result.reserve(static_cast<size_t>(len));
    for (Py_ssize_t i = 0; i < len; ++i) {
      result.emplace_back(item_to_var(PyTuple_GET_ITEM(py_obj, i)));
    }
  } else {
    PADDLE_THROW(
        "unsupported type %s, must be Variable, List[Variable] or "
        "tuple[Variable]",
        py::str(handle));
  }

  PADDLE_ENFORCE(PyErr_Occurred() == nullptr,
                 py::str(py::handle(PyErr_Occurred())));

  return result;
}

// Name -> Python value map handed to Tracer.trace() for its inputs/outputs;
// each value is something GetVarBaseListFromPyHandle accepts.
using PyVarBaseMap = std::unordered_map<std::string, py::handle>;

// Converts each value of `map` with GetVarBaseListFromPyHandle. Names whose
// value yields no VarBase at all (None, empty list/tuple) are omitted.
static imperative::VarBasePtrMap ConvertToVarBasePtrMap(
    const PyVarBaseMap &map) {
  imperative::VarBasePtrMap converted;
  for (const auto &slot : map) {
    auto vars = GetVarBaseListFromPyHandle(slot.second);
    if (vars.empty()) {
      continue;  // nothing to record for this name
    }
    converted.emplace(slot.first, std::move(vars));
  }
  return converted;
}

153
// Bind Methods
154 155 156 157
void BindImperative(pybind11::module *m_ptr) {
  auto &m = *m_ptr;

  py::class_<imperative::detail::BackwardStrategy> backward_strategy(
158 159 160 161 162 163 164 165
      m, "BackwardStrategy", R"DOC(

    BackwardStrategy is a descriptor of a how to run the backward process. Now it has:

    1. :code:`sort_sum_gradient`, which will sum the gradient by the reverse order of trace.

    Examples:

L
lujun 已提交
166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
        .. code-block:: python

          import numpy as np
          import paddle.fluid as fluid
          from paddle.fluid import FC

          x = np.ones([2, 2], np.float32)
          with fluid.dygraph.guard():
              inputs2 = []
              for _ in range(10):
                  inputs2.append(fluid.dygraph.base.to_variable(x))
              ret2 = fluid.layers.sums(inputs2)
              loss2 = fluid.layers.reduce_sum(ret2)
              backward_strategy = fluid.dygraph.BackwardStrategy()
              backward_strategy.sort_sum_gradient = True
              loss2.backward(backward_strategy)
182
      )DOC");
183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
  backward_strategy.def(py::init())
      .def_property("sort_sum_gradient",
                    [](const imperative::detail::BackwardStrategy &self) {
                      return self.sorted_sum_gradient_;
                    },
                    [](imperative::detail::BackwardStrategy &self,
                       bool sorted_sum_gradient) {
                      self.sorted_sum_gradient_ = sorted_sum_gradient;
                    });

  m.def("start_imperative_gperf_profiler",
        []() { imperative::StartProfile(); });

  m.def("stop_imperative_gperf_profiler", []() { imperative::StopProfile(); });

Z
Zeng Jinle 已提交
198 199 200 201
  m.def("_is_dygraph_debug_enabled",
        []() { return imperative::IsDebugEnabled(); });
  m.def("_dygraph_debug_level", []() { return imperative::GetDebugLevel(); });

202 203
  py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
      m, "VarBase", R"DOC()DOC")
Z
Zeng Jinle 已提交
204
      .def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267
      .def(
          py::init<const std::string &, paddle::framework::proto::VarType::Type,
                   const std::vector<int64_t>, const paddle::platform::CPUPlace,
                   bool, bool>())
      .def(
          py::init<const std::string &, paddle::framework::proto::VarType::Type,
                   const std::vector<int64_t>,
                   const paddle::platform::CUDAPlace, bool, bool>())
      .def("_run_backward",
           [](imperative::VarBase &self,
              const imperative::detail::BackwardStrategy &bckst) {
             self.RunBackward(bckst);
           })
      .def("_grad_name", &imperative::VarBase::GradName)
      .def("_grad_value", &imperative::VarBase::GradValue)
      .def("_clear_gradient", &imperative::VarBase::ClearGradient)
      .def("_grad_ivar",
           [](const imperative::VarBase &self) { return self.grads_; },
           py::return_value_policy::reference)
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::CPUPlace &place,
              bool blocking) {
             return self.NewVarBase(place, blocking).release();
           },
           py::return_value_policy::take_ownership)
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::CUDAPlace &place,
              bool blocking) {
             return self.NewVarBase(place, blocking).release();
           },
           py::return_value_policy::take_ownership)
      .def("value",
           [](const imperative::VarBase &self) { return self.var_.get(); },
           py::return_value_policy::reference)
      .def_property("name", &imperative::VarBase::Name,
                    &imperative::VarBase::SetName)
      .def_property_readonly("shape", &imperative::VarBase::Shape)
      .def_property_readonly("dtype", &imperative::VarBase::DataType)
      .def_property("persistable", &imperative::VarBase::IsPersistable,
                    &imperative::VarBase::SetPersistable)
      .def_property("stop_gradient", &imperative::VarBase::IsStopGradient,
                    &imperative::VarBase::SetStopGradient);

  py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
      .def(py::init<const std::string &>())
      .def("register_backward_hooks",
           [](imperative::OpBase &self, const py::object &callable) {
             self.RegisterBackwardHooks(callable);
           })
      .def_property("_trace_id",
                    [](const imperative::OpBase &self) {
                      py::gil_scoped_release release;
                      return self.trace_id_;
                    },
                    [](imperative::OpBase &self, int trace_id) {
                      py::gil_scoped_release release;
                      self.trace_id_ = trace_id;
                    },
                    py::return_value_policy::reference)
      .def_property_readonly("type", &imperative::OpBase::Type);

  py::class_<imperative::Layer, Layer /* <--- trampoline*/> layer(m, "Layer");
  layer.def(py::init<>())
268 269 270 271 272
      .def("forward",
           [](imperative::Layer &self,
              const std::vector<std::shared_ptr<imperative::VarBase>> &inputs) {
             return self.Forward(inputs);
           });
273

274 275 276 277 278
  // NOTE(zjl): Tracer use PyVarBaseMap as its parameter but not VarBasePtrMap.
  // We call Python C-API to convert PyVarBaseMap to VarBasePtrMap, instead
  // making conversion in Python code. This speed up Tracer.trace() about 6%
  // in ptb model and make time cost in Python to be nearly zero.
  py::class_<imperative::Tracer>(m, "Tracer", "")
279
      .def("__init__",
280
           [](imperative::Tracer &self, framework::BlockDesc *root_block) {
M
minqiyang 已提交
281
             new (&self) imperative::Tracer(root_block);
282
           })
M
minqiyang 已提交
283
      .def("trace",
284
           [](imperative::Tracer &self, imperative::OpBase *op,
285
              const PyVarBaseMap &inputs, const PyVarBaseMap &outputs,
286
              framework::AttributeMap attrs_map,
M
minqiyang 已提交
287 288
              const platform::CPUPlace expected_place,
              const bool stop_gradient = false) {
289 290 291 292 293 294 295
             auto ins = ConvertToVarBasePtrMap(inputs);
             auto outs = ConvertToVarBasePtrMap(outputs);
             {
               py::gil_scoped_release release;
               self.Trace(op, std::move(ins), &outs, attrs_map, expected_place,
                          stop_gradient);
             }
M
minqiyang 已提交
296
           })
297
      .def("trace", [](imperative::Tracer &self, imperative::OpBase *op,
298
                       const PyVarBaseMap &inputs, const PyVarBaseMap &outputs,
299 300 301
                       framework::AttributeMap attrs_map,
                       const platform::CUDAPlace expected_place,
                       const bool stop_gradient = false) {
302 303 304 305 306 307 308
        auto ins = ConvertToVarBasePtrMap(inputs);
        auto outs = ConvertToVarBasePtrMap(outputs);
        {
          py::gil_scoped_release release;
          self.Trace(op, std::move(ins), &outs, attrs_map, expected_place,
                     stop_gradient);
        }
309
      });
310 311

  // define parallel context
312 313 314
  py::class_<imperative::ParallelStrategy> parallel_strategy(
      m, "ParallelStrategy", "");
  parallel_strategy.def(py::init())
315 316
      .def_property(
          "nranks",
317 318
          [](const imperative::ParallelStrategy &self) { return self.nranks_; },
          [](imperative::ParallelStrategy &self, int nranks) {
319 320 321
            self.nranks_ = nranks;
          })
      .def_property("local_rank",
322
                    [](const imperative::ParallelStrategy &self) {
323 324
                      return self.local_rank_;
                    },
325
                    [](imperative::ParallelStrategy &self, int local_rank) {
326 327 328 329
                      self.local_rank_ = local_rank;
                    })
      .def_property(
          "trainer_endpoints",
330
          [](const imperative::ParallelStrategy &self) {
331 332
            return self.trainer_endpoints_;
          },
333
          [](imperative::ParallelStrategy &self, std::vector<std::string> eps) {
334 335 336
            self.trainer_endpoints_ = eps;
          })
      .def_property("current_endpoint",
337
                    [](const imperative::ParallelStrategy &self) {
338 339
                      return self.current_endpoint_;
                    },
340 341
                    [](imperative::ParallelStrategy &self,
                       const std::string &ep) { self.current_endpoint_ = ep; });
342
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
343 344
  py::class_<imperative::NCCLParallelContext> nccl_ctx(m,
                                                       "NCCLParallelContext");
345 346

  nccl_ctx
347 348 349
      .def(py::init<const imperative::ParallelStrategy &,
                    const platform::CUDAPlace &>())
      .def("init", [](imperative::NCCLParallelContext &self) { self.Init(); });
350
#endif
351 352 353 354
}

}  // namespace pybind
}  // namespace paddle