/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.


Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/protobuf.h"

#include <deque>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/process_mesh_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/jit/property.h"
#include "paddle/fluid/pybind/pybind_variant_caster.h"

namespace py = pybind11;

33
namespace paddle {
34
namespace pybind {
35

36 37 38
PyTypeObject *g_vartype_pytype = nullptr;
PyTypeObject *g_blockdesc_pytype = nullptr;

39
namespace pd = paddle::framework;
40
namespace jit = paddle::jit;
F
fengjiayi 已提交
41

Y
Yu Yang 已提交
42
template <typename T>
43 44
static pybind11::bytes SerializeMessage(
    T &self) {  // NOLINT due to pybind11 convention.
Y
Yu Yang 已提交
45 46
  // Check IsInitialized in Python
  std::string retv;
47 48
  PADDLE_ENFORCE_EQ(self.Proto()->SerializePartialToString(&retv),
                    true,
49 50
                    platform::errors::InvalidArgument(
                        "Failed to serialize input Desc to string."));
Y
Yu Yang 已提交
51 52 53
  return retv;
}

54 55 56 57 58 59 60 61 62
template <typename T>
static void DeserializeMessage(T *self, const std::string &str) {
  PADDLE_ENFORCE_EQ(
      self->Proto()->ParsePartialFromString(str),
      true,
      platform::errors::InvalidArgument("Failed to parse pb from string"));
  return;
}

Y
Yu Yang 已提交
63
// Bind Methods
64 65 66
void BindProgramDesc(pybind11::module *m) {
  pybind11::class_<pd::ProgramDesc>(*m, "ProgramDesc", "")
      .def(pybind11::init<>())
Y
Yu Yang 已提交
67
      .def("__init__",
68 69
           [](pd::ProgramDesc &self, const pd::ProgramDesc &other) {
             new (&self) pd::ProgramDesc(other);
Y
Yu Yang 已提交
70
           })
71
      .def("__init__",
72
           [](pd::ProgramDesc &self, const pybind11::bytes &binary_str) {
73
             std::string str(binary_str);
74
             new (&self) pd::ProgramDesc(str);
75
           })
76 77
      .def("append_block",
           &pd::ProgramDesc::AppendBlock,
78
           pybind11::return_value_policy::reference)
79 80
      .def("block",
           &pd::ProgramDesc::MutableBlock,
81 82
           pybind11::return_value_policy::reference)
      .def("num_blocks", &pd::ProgramDesc::Size)
83
      .def("flush", &pd::ProgramDesc::Flush)
84 85
      .def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames)
      .def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames)
86
      .def("serialize_to_string", SerializeMessage<pd::ProgramDesc>)
87
      .def("parse_from_string",
88 89
           [](pd::ProgramDesc &program_desc, const std::string &data) {
             pd::proto::ProgramDesc *desc = program_desc.Proto();
90
             PADDLE_ENFORCE_EQ(
91 92
                 desc->ParseFromString(data),
                 true,
93 94
                 platform::errors::InvalidArgument(
                     "Failed to parse ProgramDesc from binary string."));
X
version  
Xin Pan 已提交
95
           })
96 97 98 99 100 101
      .def(
          "_set_version",
          [](pd::ProgramDesc &self, int64_t version) {
            return self.SetVersion(version);
          },
          pybind11::arg("version") = pd::kCurProgramVersion)
102
      .def("_version",
103 104 105 106
           [](pd::ProgramDesc &self) -> int64_t { return self.Version(); })
      .def("get_op_deps", [](const framework::ProgramDesc &program) {
        return framework::ir::GetOpDependencies(program);
      });
107 108
}

109 110 111
void BindProcessMeshDesc(pybind11::module *m) {
  pybind11::class_<pd::ProcessMeshDesc>(*m, "ProcessMeshDesc", "")
      .def(pybind11::init<const std::vector<int32_t> &,
112 113
                          const std::vector<int32_t> &,
                          int32_t>())
114 115 116 117 118 119 120
      .def_property_readonly("id", &pd::ProcessMeshDesc::ID)
      .def_property_readonly("parent", &pd::ProcessMeshDesc::Parent)
      .def_property_readonly("topology", &pd::ProcessMeshDesc::Topology)
      .def_property_readonly("process_group",
                             &pd::ProcessMeshDesc::ProcessGroup);
}

121
void BindBlockDesc(pybind11::module *m) {
122 123 124
  pybind11::class_<pd::BlockDesc> blockdesc(*m, "BlockDesc", "");
  g_blockdesc_pytype = (PyTypeObject *)blockdesc.ptr();  // NOLINT
  blockdesc.def_property_readonly("id", &pd::BlockDesc::ID)
125 126
      .def_property_readonly("parent", &pd::BlockDesc::Parent)
      .def("get_forward_block_idx", &pd::BlockDesc::ForwardBlockID)
W
Wu Yi 已提交
127
      .def("_set_forward_block_idx", &pd::BlockDesc::SetForwardBlockID)
128 129
      .def("append_op",
           &pd::BlockDesc::AppendOp,
130
           pybind11::return_value_policy::reference)
131 132
      .def("_prepend_op",
           &pd::BlockDesc::PrependOp,
133
           pybind11::return_value_policy::reference)
134 135
      .def("_insert_op",
           &pd::BlockDesc::InsertOp,
136
           pybind11::return_value_policy::reference)
W
Wu Yi 已提交
137
      .def("_remove_op", &pd::BlockDesc::RemoveOp)
138 139 140 141 142 143 144 145 146 147 148 149 150 151
      .def(
          "var",
          [](pd::BlockDesc &self, pybind11::bytes byte_name) {
            std::string name = byte_name;
            return self.Var(name);
          },
          pybind11::return_value_policy::reference)
      .def(
          "has_var",
          [](pd::BlockDesc &self, pybind11::bytes byte_name) {
            std::string name = byte_name;
            return self.HasVar(name);
          },
          pybind11::return_value_policy::reference)
W
Wu Yi 已提交
152
      .def("_rename_var",
153 154
           [](pd::BlockDesc &self,
              const pybind11::bytes &byte_name,
155
              const pybind11::bytes &byte_name_new) {
T
typhoonzero 已提交
156 157
             std::string name = byte_name;
             std::string new_name = byte_name_new;
T
wip  
typhoonzero 已提交
158
             self.RenameVar(name, new_name);
Q
Qiao Longfei 已提交
159
           })
F
fengjiayi 已提交
160
      .def("has_var_recursive",
161
           [](pd::BlockDesc &self, pybind11::bytes byte_name) {
F
fengjiayi 已提交
162 163 164
             std::string name = byte_name;
             return self.HasVarRecursive(name);
           })
165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
      .def(
          "find_var",
          [](pd::BlockDesc &self, pybind11::bytes byte_name) {
            std::string name = byte_name;
            return self.FindVar(name);
          },
          pybind11::return_value_policy::reference)
      .def(
          "find_var_recursive",
          [](pd::BlockDesc &self, pybind11::bytes byte_name) {
            std::string name = byte_name;
            return self.FindVarRecursive(name);
          },
          pybind11::return_value_policy::reference)
      .def(
          "_remove_var",
          [](pd::BlockDesc &self, pybind11::bytes byte_name) {
            std::string name = byte_name;
            return self.RemoveVar(name);
          },
          pybind11::return_value_policy::reference)
186 187
      .def("all_vars",
           &pd::BlockDesc::AllVars,
188 189 190
           pybind11::return_value_policy::reference)
      .def("op_size", &pd::BlockDesc::OpSize)
      .def("op", &pd::BlockDesc::Op, pybind11::return_value_policy::reference)
191 192
      .def("serialize_to_string", SerializeMessage<pd::BlockDesc>)
      .def("_move_from", &pd::BlockDesc::MoveFrom);
193 194
}

195 196
void BindVarDsec(pybind11::module *m) {
  pybind11::class_<pd::VarDesc> var_desc(*m, "VarDesc", "");
W
WangZhen 已提交
197
  var_desc.def(pybind11::init<const std::string &>())
M
minqiyang 已提交
198
      .def("name", &pd::VarDesc::Name, pybind11::return_value_policy::reference)
199 200 201
      .def("set_name", &pd::VarDesc::SetName)
      .def("set_shape", &pd::VarDesc::SetShape)
      .def("set_shapes", &pd::VarDesc::SetShapes)
H
hong 已提交
202
      .def("get_shape", &pd::VarDesc::GetShape)
203 204
      .def("set_dtype", &pd::VarDesc::SetDataType)
      .def("set_dtypes", &pd::VarDesc::SetDataTypes)
205 206
      .def("shape",
           &pd::VarDesc::GetShape,
207
           pybind11::return_value_policy::reference)
208 209
      .def("shapes",
           &pd::VarDesc::GetShapes,
210
           pybind11::return_value_policy::reference)
211 212
      .def("dtype",
           &pd::VarDesc::GetDataType,
213
           pybind11::return_value_policy::reference)
214 215
      .def("element_size",
           &pd::VarDesc::ElementSize,
216
           pybind11::return_value_policy::reference)
217 218
      .def("dtypes",
           &pd::VarDesc::GetDataTypes,
219 220
           pybind11::return_value_policy::reference)
      .def("lod_level", &pd::VarDesc::GetLoDLevel)
221 222
      .def("lod_levels",
           &pd::VarDesc::GetLoDLevels,
223 224 225 226 227 228 229
           pybind11::return_value_policy::reference)
      .def("set_lod_level", &pd::VarDesc::SetLoDLevel)
      .def("set_lod_levels", &pd::VarDesc::SetLoDLevels)
      .def("type", &pd::VarDesc::GetType)
      .def("set_type", &pd::VarDesc::SetType)
      .def("serialize_to_string", SerializeMessage<pd::VarDesc>)
      .def("persistable", &pd::VarDesc::Persistable)
H
Huihuang Zheng 已提交
230
      .def("set_persistable", &pd::VarDesc::SetPersistable)
231 232 233 234 235 236 237 238
      .def("is_parameter", &pd::VarDesc::IsParameter)
      .def("set_is_parameter", &pd::VarDesc::SetIsParameter)
      .def("clear_is_parameter", &pd::VarDesc::ClearIsParameter)
      .def("has_is_parameter", &pd::VarDesc::HasIsParameter)
      .def("stop_gradient", &pd::VarDesc::StopGradient)
      .def("set_stop_gradient", &pd::VarDesc::SetStopGradient)
      .def("clear_stop_gradient", &pd::VarDesc::ClearStopGradient)
      .def("has_stop_gradient", &pd::VarDesc::HasStopGradient)
H
Huihuang Zheng 已提交
239
      .def("need_check_feed", &pd::VarDesc::NeedCheckFeed)
240 241 242 243 244
      .def("set_need_check_feed", &pd::VarDesc::SetNeedCheckFeed)
      .def("has_attr", &pd::VarDesc::HasAttr)
      .def("attr_names", &pd::VarDesc::AttrNames)
      .def("_set_attr", &pd::VarDesc::SetAttr)
      .def("remove_attr", &pd::VarDesc::RemoveAttr)
245
      .def("id", &pd::VarDesc::Id)
246 247
      .def("original_id", &pd::VarDesc::OriginalId)
      .def("set_original_id", &pd::VarDesc::SetOriginalId)
248
      .def("attr", &pd::VarDesc::GetAttr);
Y
Yu Yang 已提交
249

250 251 252
  pybind11::enum_<pd::proto::VarType::Type> vartype(var_desc, "VarType", "");
  g_vartype_pytype = (PyTypeObject *)vartype.ptr();  // NOLINT
  vartype.value("BOOL", pd::proto::VarType::BOOL)
253
      .value("UINT8", pd::proto::VarType::UINT8)
Q
qingqing01 已提交
254
      .value("INT8", pd::proto::VarType::INT8)
255 256 257 258 259 260
      .value("INT16", pd::proto::VarType::INT16)
      .value("INT32", pd::proto::VarType::INT32)
      .value("INT64", pd::proto::VarType::INT64)
      .value("FP16", pd::proto::VarType::FP16)
      .value("FP32", pd::proto::VarType::FP32)
      .value("FP64", pd::proto::VarType::FP64)
261
      .value("BF16", pd::proto::VarType::BF16)
262 263
      .value("COMPLEX64", pd::proto::VarType::COMPLEX64)
      .value("COMPLEX128", pd::proto::VarType::COMPLEX128)
264 265 266 267 268 269 270 271 272
      .value("LOD_TENSOR", pd::proto::VarType::LOD_TENSOR)
      .value("SELECTED_ROWS", pd::proto::VarType::SELECTED_ROWS)
      .value("FEED_MINIBATCH", pd::proto::VarType::FEED_MINIBATCH)
      .value("FETCH_LIST", pd::proto::VarType::FETCH_LIST)
      .value("STEP_SCOPES", pd::proto::VarType::STEP_SCOPES)
      .value("LOD_RANK_TABLE", pd::proto::VarType::LOD_RANK_TABLE)
      .value("LOD_TENSOR_ARRAY", pd::proto::VarType::LOD_TENSOR_ARRAY)
      .value("PLACE_LIST", pd::proto::VarType::PLACE_LIST)
      .value("READER", pd::proto::VarType::READER)
S
Steffy-zxf 已提交
273 274 275 276
      .value("RAW", pd::proto::VarType::RAW)
      .value("STRING", pd::proto::VarType::STRING)
      .value("STRINGS", pd::proto::VarType::STRINGS)
      .value("VOCAB", pd::proto::VarType::VOCAB);
277 278
}

279 280 281 282
void BindOpDesc(pybind11::module *m) {
  pybind11::enum_<pd::proto::AttrType>(*m, "AttrType", "")
      .value("INT", pd::proto::AttrType::INT)
      .value("INTS", pd::proto::AttrType::INTS)
T
tangwei12 已提交
283 284
      .value("LONG", pd::proto::AttrType::LONG)
      .value("LONGS", pd::proto::AttrType::LONGS)
285 286 287 288 289 290
      .value("FLOAT", pd::proto::AttrType::FLOAT)
      .value("FLOATS", pd::proto::AttrType::FLOATS)
      .value("STRING", pd::proto::AttrType::STRING)
      .value("STRINGS", pd::proto::AttrType::STRINGS)
      .value("BOOL", pd::proto::AttrType::BOOLEAN)
      .value("BOOLS", pd::proto::AttrType::BOOLEANS)
Y
Yancey1989 已提交
291 292
      .value("BLOCK", pd::proto::AttrType::BLOCK)
      .value("BLOCKS", pd::proto::AttrType::BLOCKS);
Y
Yu Yang 已提交
293

294
  pybind11::class_<pd::OpDesc> op_desc(*m, "OpDesc", "");
F
fengjiayi 已提交
295
  op_desc
296
      .def(
297 298
          "__init__",
          [](pd::OpDesc &self) { new (&self) pd::OpDesc(); },
299
          pybind11::return_value_policy::reference)
300 301 302 303 304 305 306
      .def("copy_from", &pd::OpDesc::CopyFrom)
      .def("type", &pd::OpDesc::Type)
      .def("set_type", &pd::OpDesc::SetType)
      .def("input", &pd::OpDesc::Input)
      .def("input_names", &pd::OpDesc::InputNames)
      .def("output", &pd::OpDesc::Output)
      .def("output_names", &pd::OpDesc::OutputNames)
H
hong 已提交
307
      .def("set_input",
308 309
           [](pd::OpDesc &self,
              const std::string &name,
H
hong 已提交
310 311 312 313
              const std::vector<std::string> &vec_var_name) {
             self.SetInput(name, vec_var_name);
           })
      .def("set_output",
314 315
           [](pd::OpDesc &self,
              const std::string &name,
H
hong 已提交
316 317 318
              const std::vector<std::string> &vec_var_name) {
             self.SetOutput(name, vec_var_name);
           })
319
      .def("remove_output", &pd::OpDesc::RemoveOutput)
320
      .def("remove_input", &pd::OpDesc::RemoveInput)
321 322
      .def("input_arg_names", &pd::OpDesc::InputArgumentNames)
      .def("output_arg_names", &pd::OpDesc::OutputArgumentNames)
W
Wu Yi 已提交
323 324
      .def("_rename_input", &pd::OpDesc::RenameInput)
      .def("_rename_output", &pd::OpDesc::RenameOutput)
325 326 327
      .def("has_attr", &pd::OpDesc::HasAttr)
      .def("attr_type", &pd::OpDesc::GetAttrType)
      .def("attr_names", &pd::OpDesc::AttrNames)
W
Wu Yi 已提交
328
      .def("_set_attr", &pd::OpDesc::SetAttr)
329
      .def("remove_attr", &pd::OpDesc::RemoveAttr)
330 331
      .def("attr", &pd::OpDesc::GetAttr)
      .def("set_block_attr", &pd::OpDesc::SetBlockAttr)
332
      .def("set_blocks_attr", &pd::OpDesc::SetBlocksAttr)
T
typhoonzero 已提交
333
      .def("set_serialized_attr",
334 335
           [](pd::OpDesc &self,
              const std::string &name,
336
              const pybind11::bytes &seriralized) {
T
typhoonzero 已提交
337 338 339
             std::string ser(seriralized);
             self.SetAttr(name, ser);
           })
W
Wu Yi 已提交
340 341
      .def("_block_attr_id", &pd::OpDesc::GetBlockAttrId)
      .def("_blocks_attr_ids", &pd::OpDesc::GetBlocksAttrIds)
342 343 344
      .def("check_attrs", &pd::OpDesc::CheckAttrs)
      .def("infer_shape", &pd::OpDesc::InferShape)
      .def("infer_var_type", &pd::OpDesc::InferVarType)
345
      .def("set_is_target", &pd::OpDesc::SetIsTarget)
346
      .def("serialize_to_string", SerializeMessage<pd::OpDesc>)
347
      .def(
348 349
          "block",
          [](pd::OpDesc &self) { return self.Block(); },
350
          pybind11::return_value_policy::reference)
351
      .def("id", &pd::OpDesc::Id)
352 353
      .def("original_id", &pd::OpDesc::OriginalId)
      .def("set_original_id", &pd::OpDesc::SetOriginalId)
354 355
      .def("inputs", &pd::OpDesc::Inputs)
      .def("outputs", &pd::OpDesc::Outputs);
356
}
Y
Yu Yang 已提交
357

358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439
// Serialize Class Property
void BindJitProperty(pybind11::module *m) {
  pybind11::class_<jit::Property> property(*m, "Property");
  property
      .def(
          "__init__",
          [](jit::Property &self) { new (&self) jit::Property(); },
          pybind11::return_value_policy::reference)
      .def("size", &jit::Property::Size)
      .def("set_float",
           py::overload_cast<const float &>(&jit::Property::SetFloat),
           "set float",
           py::arg("val"))
      .def("set_float",
           py::overload_cast<const std::string &, const float &>(
               &jit::Property::SetFloat),
           "set float",
           py::arg("name"),
           py::arg("var"))
      .def("get_float",
           py::overload_cast<const int &>(&jit::Property::GetFloat, py::const_))
      .def("get_float",
           py::overload_cast<const std::string &>(&jit::Property::GetFloat,
                                                  py::const_))
      .def("set_floats",
           py::overload_cast<const std::vector<float> &>(
               &jit::Property::SetFloats),
           "set list of float",
           py::arg("vals"))
      .def("set_floats",
           py::overload_cast<const std::string &, const std::vector<float> &>(
               &jit::Property::SetFloats),
           "set list of float",
           py::arg("name"),
           py::arg("val"))
      .def("set_int",
           py::overload_cast<const int64_t &>(&jit::Property::SetInt64),
           "set int",
           py::arg("val"))
      .def("set_int",
           py::overload_cast<const std::string &, const int64_t &>(
               &jit::Property::SetInt64),
           "set int",
           py::arg("name"),
           py::arg("val"))
      .def("set_ints",
           py::overload_cast<const std::vector<int64_t> &>(
               &jit::Property::SetInt64s),
           "set list of int",
           py::arg("vals"))
      .def("set_ints",
           py::overload_cast<const std::string &, const std::vector<int64_t> &>(
               &jit::Property::SetInt64s),
           "set list of int",
           py::arg("name"),
           py::arg("val"))
      .def("set_string",
           py::overload_cast<const std::string &>(&jit::Property::SetString),
           "set string",
           py::arg("val"))
      .def("set_string",
           py::overload_cast<const std::string &, const std::string &>(
               &jit::Property::SetString),
           "set string",
           py::arg("name"),
           py::arg("val"))
      .def("set_strings",
           py::overload_cast<const std::vector<std::string> &>(
               &jit::Property::SetStrings),
           "set list of string",
           py::arg("vals"))
      .def("set_strings",
           py::overload_cast<const std::string &,
                             const std::vector<std::string> &>(
               &jit::Property::SetStrings),
           "set list of string",
           py::arg("name"),
           py::arg("val"))
      .def("serialize_to_string", SerializeMessage<jit::Property>)
      .def("parse_from_string", DeserializeMessage<jit::Property>);
}

440
}  // namespace pybind
441
}  // namespace paddle