提交 eeb7c8ad 编写于 作者: F fengjiayi

Compelete VarDescBind

上级 e05e27a7
......@@ -46,12 +46,24 @@ class VarDescBind;
class VarDescBind {
public:
explicit VarDescBind(const std::string &name) { var_desc_.set_name(name); }
explicit VarDescBind(const std::string &name) { desc_.set_name(name); }
VarDesc *Proto() { return &var_desc_; }
VarDesc *Proto() { return &desc_; }
void SetShape(const vector<int64_t> &dims) {
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
}
void SetDataType(int type_id) {
desc_.mutable_lod_tensor()->set_data_type(const_cast<DataType>(type_id));
}
std::vector<int64_t> Shape() {
return RepeatedToVector(desc_.lod_tensor().dims());
}
private:
VarDesc var_desc_;
VarDesc desc_;
};
class OpDescBind {
......@@ -217,27 +229,11 @@ void BindBlockDesc(py::module &m) {
}
void BindVarDsec(py::module &m) {
py::class_<VarDesc>(m, "VarDesc", "");
// using namespace paddle::framework; // NOLINT
// py::class_<VarDesc>(m, "VarDesc", "")
// .def(py::init<>())
// .def("set_name",
// [](VarDesc &self, const std::string &name) { self.set_name(name);
// })
// .def("set_shape",
// [](VarDesc &self, const std::vector<int64_t> &dims) {
// VectorToRepeated(dims,
// self.mutable_lod_tensor()->mutable_dims());
// })
// .def("set_data_type",
// [](VarDesc &self, int type_id) {
// LoDTensorDesc *lod_tensor_desc = self.mutable_lod_tensor();
// lod_tensor_desc->set_data_type(static_cast<DataType>(type_id));
// })
// .def("shape", [](VarDesc &self) {
// const LoDTensorDesc &lod_tensor_desc = self.lod_tensor();
// return RepeatedToVector(lod_tensor_desc.dims());
// });
py::class_<VarDesc>(m, "VarDesc", "")
.def(py::init<>())
.def("set_shape", VarDescBind::SetShape)
.def("set_data_type", VarDescBind::SetDataType)
.def("shape", VarDescBind::Shape);
}
void BindOpDesc(py::module &m) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册