提交 9fa7c930 编写于 作者: Y Yu Yang

Merge branch 'feature/pybind_for_protobuf_desc' of github.com:reyoung/Paddle...

Merge branch 'feature/pybind_for_protobuf_desc' of github.com:reyoung/Paddle into feature/pybind_for_protobuf_desc
...@@ -106,7 +106,7 @@ enum DataType { ...@@ -106,7 +106,7 @@ enum DataType {
message LoDTensorDesc { message LoDTensorDesc {
required DataType data_type = 1; required DataType data_type = 1;
repeated int32 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
optional int32 lod_level = 3 [ default = 0 ]; optional int32 lod_level = 3 [ default = 0 ];
} }
......
...@@ -47,12 +47,24 @@ class VarDescBind; ...@@ -47,12 +47,24 @@ class VarDescBind;
class VarDescBind { class VarDescBind {
public: public:
explicit VarDescBind(const std::string &name) { var_desc_.set_name(name); } explicit VarDescBind(const std::string &name) { desc_.set_name(name); }
VarDesc *Proto() { return &var_desc_; } VarDesc *Proto() { return &desc_; }
void SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
}
void SetDataType(int type_id) {
desc_.mutable_lod_tensor()->set_data_type(static_cast<DataType>(type_id));
}
std::vector<int64_t> Shape() {
return RepeatedToVector(desc_.lod_tensor().dims());
}
private: private:
VarDesc var_desc_; VarDesc desc_;
}; };
class OpDescBind { class OpDescBind {
...@@ -170,7 +182,8 @@ public: ...@@ -170,7 +182,8 @@ public:
int32_t Parent() const { return desc_->parent_idx(); } int32_t Parent() const { return desc_->parent_idx(); }
VarDescBind *NewVar(const std::string &name) { VarDescBind *NewVar(py::bytes name_bytes) {
std::string name = name_bytes;
need_update_ = true; need_update_ = true;
auto it = vars_.find(name); auto it = vars_.find(name);
PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name); PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name);
...@@ -303,32 +316,15 @@ void BindBlockDesc(py::module &m) { ...@@ -303,32 +316,15 @@ void BindBlockDesc(py::module &m) {
&BlockDescBind::AppendOp, &BlockDescBind::AppendOp,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("new_var", .def("new_var",
[](BlockDesc &self) { return self.add_vars(); }, &BlockDescBind::NewVar,
py::return_value_policy::reference); py::return_value_policy::reference);
} }
void BindVarDsec(py::module &m) { void BindVarDsec(py::module &m) {
py::class_<VarDesc>(m, "VarDesc", ""); py::class_<VarDescBind>(m, "VarDesc", "")
// using namespace paddle::framework; // NOLINT .def("set_shape", &VarDescBind::SetShape)
// py::class_<VarDesc>(m, "VarDesc", "") .def("set_data_type", &VarDescBind::SetDataType)
// .def(py::init<>()) .def("shape", &VarDescBind::Shape);
// .def("set_name",
// [](VarDesc &self, const std::string &name) { self.set_name(name);
// })
// .def("set_shape",
// [](VarDesc &self, const std::vector<int64_t> &dims) {
// VectorToRepeated(dims,
// self.mutable_lod_tensor()->mutable_dims());
// })
// .def("set_data_type",
// [](VarDesc &self, int type_id) {
// LoDTensorDesc *lod_tensor_desc = self.mutable_lod_tensor();
// lod_tensor_desc->set_data_type(static_cast<DataType>(type_id));
// })
// .def("shape", [](VarDesc &self) {
// const LoDTensorDesc &lod_tensor_desc = self.lod_tensor();
// return RepeatedToVector(lod_tensor_desc.dims());
// });
} }
void BindOpDesc(py::module &m) { void BindOpDesc(py::module &m) {
......
...@@ -52,7 +52,7 @@ class TestVarDesc(unittest.TestCase): ...@@ -52,7 +52,7 @@ class TestVarDesc(unittest.TestCase):
def test_shape(self): def test_shape(self):
program_desc = core.ProgramDesc.instance() program_desc = core.ProgramDesc.instance()
block = program_desc.root_block() block = program_desc.root_block()
var = block.new_var() var = block.new_var('my_var')
src_shape = [3, 2, 10, 8] src_shape = [3, 2, 10, 8]
var.set_shape(src_shape) var.set_shape(src_shape)
res_shape = var.shape() res_shape = var.shape()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册