/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/pybind/protobuf.h"

#include <cstdint>
#include <deque>
#include <iostream>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"

// Cast boost::variant for PyBind.
// Copied from
// https://github.com/pybind/pybind11/issues/576#issuecomment-269563199
namespace pybind11 {
namespace detail {

P
peizhilin 已提交
32 33 34 35 36 37 38 39
// Fallback definition of PYBIND11_HIDDEN, used only when nothing earlier in
// this translation unit has already defined the macro.
// NOTE(review): on _WIN32 this expands to dllexport rather than to anything
// that hides the symbol -- confirm that this is the intended behaviour.
#if !defined(PYBIND11_HIDDEN)
#ifdef _WIN32
#define PYBIND11_HIDDEN __declspec(dllexport)
#else
// Non-Windows toolchains: give the annotated type hidden symbol visibility.
#define PYBIND11_HIDDEN __attribute__((visibility("hidden")))
#endif
#endif

Y
Yu Yang 已提交
40
// Can be replaced by a generic lambda in C++14
P
peizhilin 已提交
41
struct PYBIND11_HIDDEN paddle_variant_caster_visitor
42
    : public boost::static_visitor<handle> {
Y
Yu Yang 已提交
43 44 45
  return_value_policy policy;
  handle parent;

46
  paddle_variant_caster_visitor(return_value_policy policy, handle parent)
Y
Yu Yang 已提交
47 48 49 50 51 52 53 54 55
      : policy(policy), parent(parent) {}

  template <class T>
  handle operator()(T const &src) const {
    return make_caster<T>::cast(src, policy, parent);
  }
};

template <class Variant>
56
struct paddle_variant_caster;
Y
Yu Yang 已提交
57 58

template <template <class...> class V, class... Ts>
59
struct paddle_variant_caster<V<Ts...>> {
Y
Yu Yang 已提交
60 61
  using Type = V<Ts...>;

Y
Yu Yang 已提交
62 63
  template <typename T>
  typename std::enable_if<
Y
Yu Yang 已提交
64
      !std::is_same<T, boost::detail::variant::void_>::value, bool>::type
Y
Yu Yang 已提交
65
  try_load(handle src, bool convert) {
Y
Yu Yang 已提交
66 67 68
    auto caster = make_caster<T>();
    if (!load_success_ && caster.load(src, convert)) {
      load_success_ = true;
S
seiriosPlus 已提交
69 70 71 72

      if (std::is_same<T, std::vector<float>>::value) {
        auto caster_ints = make_caster<std::vector<int64_t>>();
        if (caster_ints.load(src, convert)) {
M
minqiyang 已提交
73 74 75
          VLOG(4) << "This value are floats and int64_ts satisfy "
                     "simultaneously, will set it's type to "
                     "std::vector<int64_t>";
S
seiriosPlus 已提交
76 77 78 79 80
          value = cast_op<std::vector<int64_t>>(caster_ints);
          return true;
        }
      }

Y
Yu Yang 已提交
81 82 83 84 85 86
      value = cast_op<T>(caster);
      return true;
    }
    return false;
  }

Y
Yu Yang 已提交
87 88 89 90 91 92 93
  template <typename T>
  typename std::enable_if<std::is_same<T, boost::detail::variant::void_>::value,
                          bool>::type
  try_load(handle src, bool convert) {
    return false;
  }

Y
Yu Yang 已提交
94 95 96 97 98 99
  bool load(handle src, bool convert) {
    auto unused = {false, try_load<Ts>(src, convert)...};
    (void)(unused);
    return load_success_;
  }

Y
Yu Yang 已提交
100
  static handle cast(Type const &src, return_value_policy policy,
Y
Yu Yang 已提交
101
                     handle parent) {
102
    paddle_variant_caster_visitor visitor(policy, parent);
Y
Yu Yang 已提交
103 104 105 106 107 108 109 110 111 112
    return boost::apply_visitor(visitor, src);
  }

  PYBIND11_TYPE_CASTER(Type, _("Variant"));
  bool load_success_{false};
};

// Add specialization for concrete variant type
template <class... Args>
struct type_caster<boost::variant<Args...>>
113
    : paddle_variant_caster<boost::variant<Args...>> {};
Y
Yu Yang 已提交
114 115 116 117

}  // namespace detail
}  // namespace pybind11

118
namespace paddle {
119
namespace pybind {
120

121
namespace pd = paddle::framework;
F
fengjiayi 已提交
122

Y
Yu Yang 已提交
123
template <typename T>
124 125
static pybind11::bytes SerializeMessage(
    T &self) {  // NOLINT due to pybind11 convention.
Y
Yu Yang 已提交
126 127 128 129 130 131 132
  // Check IsInitialized in Python
  std::string retv;
  PADDLE_ENFORCE(self.Proto()->SerializePartialToString(&retv),
                 "Cannot serialize message");
  return retv;
}

Y
Yu Yang 已提交
133
// Bind Methods
134 135 136
void BindProgramDesc(pybind11::module *m) {
  pybind11::class_<pd::ProgramDesc>(*m, "ProgramDesc", "")
      .def(pybind11::init<>())
Y
Yu Yang 已提交
137
      .def("__init__",
138 139
           [](pd::ProgramDesc &self, const pd::ProgramDesc &other) {
             new (&self) pd::ProgramDesc(other);
Y
Yu Yang 已提交
140
           })
141
      .def("__init__",
142
           [](pd::ProgramDesc &self, const pybind11::bytes &binary_str) {
143
             std::string str(binary_str);
144
             new (&self) pd::ProgramDesc(str);
145
           })
146 147 148 149 150
      .def("append_block", &pd::ProgramDesc::AppendBlock,
           pybind11::return_value_policy::reference)
      .def("block", &pd::ProgramDesc::MutableBlock,
           pybind11::return_value_policy::reference)
      .def("num_blocks", &pd::ProgramDesc::Size)
151
      .def("flush", &pd::ProgramDesc::Flush)
152 153
      .def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames)
      .def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames)
154
      .def("serialize_to_string", SerializeMessage<pd::ProgramDesc>)
155
      .def("parse_from_string",
156 157
           [](pd::ProgramDesc &program_desc, const std::string &data) {
             pd::proto::ProgramDesc *desc = program_desc.Proto();
158 159 160
             PADDLE_ENFORCE(desc->ParseFromString(data),
                            "Fail to parse ProgramDesc from string. This could "
                            "be a bug of Paddle.");
X
version  
Xin Pan 已提交
161 162 163 164
           })
      .def("_version", [](pd::ProgramDesc &self) -> int64_t {
        return self.Proto()->version().version();
      });
165 166
}

167 168 169 170 171
// Registers pd::BlockDesc with the Python module `m` as "BlockDesc".
// Variable names arrive from Python as `bytes` and are converted to
// std::string before being passed to the BlockDesc API. Accessors returning
// op/var descs use return_value_policy::reference, so Python does not take
// ownership of the returned objects.
void BindBlockDesc(pybind11::module *m) {
  pybind11::class_<pd::BlockDesc>(*m, "BlockDesc", "")
      .def_property_readonly("id", &pd::BlockDesc::ID)
      .def_property_readonly("parent", &pd::BlockDesc::Parent)
      .def("get_forward_block_idx", &pd::BlockDesc::ForwardBlockID)
      .def("_set_forward_block_idx", &pd::BlockDesc::SetForwardBlockID)
      .def("append_op", &pd::BlockDesc::AppendOp,
           pybind11::return_value_policy::reference)
      .def("_prepend_op", &pd::BlockDesc::PrependOp,
           pybind11::return_value_policy::reference)
      .def("_insert_op", &pd::BlockDesc::InsertOp,
           pybind11::return_value_policy::reference)
      .def("_remove_op", &pd::BlockDesc::RemoveOp)
      .def("var",
           [](pd::BlockDesc &self, pybind11::bytes byte_name) {
             std::string name = byte_name;
             return self.Var(name);
           },
           pybind11::return_value_policy::reference)
      .def("has_var",
           [](pd::BlockDesc &self, pybind11::bytes byte_name) {
             std::string name = byte_name;
             return self.HasVar(name);
           },
           pybind11::return_value_policy::reference)
      .def("_rename_var",
           [](pd::BlockDesc &self, const pybind11::bytes &byte_name,
              const pybind11::bytes &byte_name_new) {
             std::string name = byte_name;
             std::string new_name = byte_name_new;
             self.RenameVar(name, new_name);
           })
      .def("has_var_recursive",
           [](pd::BlockDesc &self, pybind11::bytes byte_name) {
             std::string name = byte_name;
             return self.HasVarRecursive(name);
           })
      .def("find_var",
           [](pd::BlockDesc &self, pybind11::bytes byte_name) {
             std::string name = byte_name;
             return self.FindVar(name);
           },
           pybind11::return_value_policy::reference)
      .def("find_var_recursive",
           [](pd::BlockDesc &self, pybind11::bytes byte_name) {
             std::string name = byte_name;
             return self.FindVarRecursive(name);
           },
           pybind11::return_value_policy::reference)
      .def("_remove_var",
           [](pd::BlockDesc &self, pybind11::bytes byte_name) {
             std::string name = byte_name;
             return self.RemoveVar(name);
           },
           pybind11::return_value_policy::reference)
      .def("all_vars", &pd::BlockDesc::AllVars,
           pybind11::return_value_policy::reference)
      .def("op_size", &pd::BlockDesc::OpSize)
      .def("op", &pd::BlockDesc::Op, pybind11::return_value_policy::reference)
      .def("serialize_to_string", SerializeMessage<pd::BlockDesc>);
}

229 230
// Registers pd::VarDesc with the Python module `m` as "VarDesc", and nests
// the pd::proto::VarType::Type enum inside it as "VarDesc.VarType".
// NOTE(review): the function name is misspelled ("Dsec"); it is kept as-is
// because callers outside this file presumably refer to it by this name --
// verify before renaming.
void BindVarDsec(pybind11::module *m) {
  pybind11::class_<pd::VarDesc> var_desc(*m, "VarDesc", "");
  var_desc
      .def("name", &pd::VarDesc::Name, pybind11::return_value_policy::reference)
      .def("set_name", &pd::VarDesc::SetName)
      .def("set_shape", &pd::VarDesc::SetShape)
      .def("set_shapes", &pd::VarDesc::SetShapes)
      .def("set_dtype", &pd::VarDesc::SetDataType)
      .def("set_dtypes", &pd::VarDesc::SetDataTypes)
      .def("shape", &pd::VarDesc::GetShape,
           pybind11::return_value_policy::reference)
      .def("shapes", &pd::VarDesc::GetShapes,
           pybind11::return_value_policy::reference)
      .def("dtype", &pd::VarDesc::GetDataType,
           pybind11::return_value_policy::reference)
      .def("dtypes", &pd::VarDesc::GetDataTypes,
           pybind11::return_value_policy::reference)
      .def("lod_level", &pd::VarDesc::GetLoDLevel)
      .def("lod_levels", &pd::VarDesc::GetLoDLevels,
           pybind11::return_value_policy::reference)
      .def("set_lod_level", &pd::VarDesc::SetLoDLevel)
      .def("set_lod_levels", &pd::VarDesc::SetLoDLevels)
      .def("type", &pd::VarDesc::GetType)
      .def("set_type", &pd::VarDesc::SetType)
      .def("serialize_to_string", SerializeMessage<pd::VarDesc>)
      .def("persistable", &pd::VarDesc::Persistable)
      .def("set_persistable", &pd::VarDesc::SetPersistable);

  // The enum is scoped under the VarDesc class object rather than the module.
  pybind11::enum_<pd::proto::VarType::Type>(var_desc, "VarType", "")
      .value("BOOL", pd::proto::VarType::BOOL)
      .value("UINT8", pd::proto::VarType::UINT8)
      .value("INT8", pd::proto::VarType::INT8)
      .value("INT16", pd::proto::VarType::INT16)
      .value("INT32", pd::proto::VarType::INT32)
      .value("INT64", pd::proto::VarType::INT64)
      .value("FP16", pd::proto::VarType::FP16)
      .value("FP32", pd::proto::VarType::FP32)
      .value("FP64", pd::proto::VarType::FP64)
      .value("LOD_TENSOR", pd::proto::VarType::LOD_TENSOR)
      .value("SELECTED_ROWS", pd::proto::VarType::SELECTED_ROWS)
      .value("FEED_MINIBATCH", pd::proto::VarType::FEED_MINIBATCH)
      .value("FETCH_LIST", pd::proto::VarType::FETCH_LIST)
      .value("STEP_SCOPES", pd::proto::VarType::STEP_SCOPES)
      .value("LOD_RANK_TABLE", pd::proto::VarType::LOD_RANK_TABLE)
      .value("LOD_TENSOR_ARRAY", pd::proto::VarType::LOD_TENSOR_ARRAY)
      .value("PLACE_LIST", pd::proto::VarType::PLACE_LIST)
      .value("READER", pd::proto::VarType::READER)
      .value("RAW", pd::proto::VarType::RAW);
}

279 280 281 282
// Registers the pd::proto::AttrType enum ("AttrType") and pd::OpDesc
// ("OpDesc") with the Python module `m`.
void BindOpDesc(pybind11::module *m) {
  // Note: the Python-side names BOOL/BOOLS map to the proto enumerators
  // BOOLEAN/BOOLEANS.
  pybind11::enum_<pd::proto::AttrType>(*m, "AttrType", "")
      .value("INT", pd::proto::AttrType::INT)
      .value("INTS", pd::proto::AttrType::INTS)
      .value("LONG", pd::proto::AttrType::LONG)
      .value("LONGS", pd::proto::AttrType::LONGS)
      .value("FLOAT", pd::proto::AttrType::FLOAT)
      .value("FLOATS", pd::proto::AttrType::FLOATS)
      .value("STRING", pd::proto::AttrType::STRING)
      .value("STRINGS", pd::proto::AttrType::STRINGS)
      .value("BOOL", pd::proto::AttrType::BOOLEAN)
      .value("BOOLS", pd::proto::AttrType::BOOLEANS)
      .value("BLOCK", pd::proto::AttrType::BLOCK)
      .value("BLOCKS", pd::proto::AttrType::BLOCKS);

  pybind11::class_<pd::OpDesc> op_desc(*m, "OpDesc", "");
  op_desc
      // Default-construct in place (placement new into the storage pybind11
      // passes in as `self`).
      .def("__init__", [](pd::OpDesc &self) { new (&self) pd::OpDesc(); },
           pybind11::return_value_policy::reference)
      .def("copy_from", &pd::OpDesc::CopyFrom)
      .def("type", &pd::OpDesc::Type)
      .def("set_type", &pd::OpDesc::SetType)
      .def("input", &pd::OpDesc::Input)
      .def("input_names", &pd::OpDesc::InputNames)
      .def("output", &pd::OpDesc::Output)
      .def("output_names", &pd::OpDesc::OutputNames)
      .def("set_input", &pd::OpDesc::SetInput)
      .def("set_output", &pd::OpDesc::SetOutput)
      .def("input_arg_names", &pd::OpDesc::InputArgumentNames)
      .def("output_arg_names", &pd::OpDesc::OutputArgumentNames)
      .def("_rename_input", &pd::OpDesc::RenameInput)
      .def("_rename_output", &pd::OpDesc::RenameOutput)
      .def("has_attr", &pd::OpDesc::HasAttr)
      .def("attr_type", &pd::OpDesc::GetAttrType)
      .def("attr_names", &pd::OpDesc::AttrNames)
      .def("_set_attr", &pd::OpDesc::SetAttr)
      .def("attr", &pd::OpDesc::GetAttr)
      .def("set_block_attr", &pd::OpDesc::SetBlockAttr)
      .def("set_blocks_attr", &pd::OpDesc::SetBlocksAttr)
      // Stores a `bytes` payload as a std::string attribute.
      .def("set_serialized_attr",
           [](pd::OpDesc &self, const std::string &name,
              const pybind11::bytes &serialized) {
             std::string ser(serialized);
             self.SetAttr(name, ser);
           })
      .def("_block_attr_id", &pd::OpDesc::GetBlockAttrId)
      .def("_blocks_attr_ids", &pd::OpDesc::GetBlocksAttrIds)
      .def("check_attrs", &pd::OpDesc::CheckAttrs)
      .def("infer_shape", &pd::OpDesc::InferShape)
      .def("infer_var_type", &pd::OpDesc::InferVarType)
      .def("set_is_target", &pd::OpDesc::SetIsTarget)
      .def("serialize_to_string", SerializeMessage<pd::OpDesc>)
      // Python does not take ownership of the returned block desc.
      .def("block", &pd::OpDesc::Block,
           pybind11::return_value_policy::reference);
}
Y
Yu Yang 已提交
334

335
}  // namespace pybind
336
}  // namespace paddle