提交 08e99006 编写于 作者: F fengjiayi

Fix bugs

上级 57c95c79
...@@ -50,12 +50,12 @@ public: ...@@ -50,12 +50,12 @@ public:
VarDesc *Proto() { return &desc_; } VarDesc *Proto() { return &desc_; }
void SetShape(const vector<int64_t> &dims) { void SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
} }
void SetDataType(int type_id) { void SetDataType(int type_id) {
desc_.mutable_lod_tensor()->set_data_type(const_cast<DataType>(type_id)); desc_.mutable_lod_tensor()->set_data_type(static_cast<DataType>(type_id));
} }
std::vector<int64_t> Shape() { std::vector<int64_t> Shape() {
...@@ -86,7 +86,8 @@ public: ...@@ -86,7 +86,8 @@ public:
int32_t Parent() const { return desc_->parent_idx(); } int32_t Parent() const { return desc_->parent_idx(); }
VarDescBind *NewVar(const std::string &name) { VarDescBind *NewVar(py::bytes name_bytes) {
std::string name = name_bytes;
need_update_ = true; need_update_ = true;
auto it = vars_.find(name); auto it = vars_.find(name);
PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name); PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name);
...@@ -224,16 +225,15 @@ void BindBlockDesc(py::module &m) { ...@@ -224,16 +225,15 @@ void BindBlockDesc(py::module &m) {
&BlockDescBind::AppendOp, &BlockDescBind::AppendOp,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("new_var", .def("new_var",
[](BlockDesc &self) { return self.add_vars(); }, &BlockDescBind::NewVar,
py::return_value_policy::reference); py::return_value_policy::reference);
} }
void BindVarDsec(py::module &m) { void BindVarDsec(py::module &m) {
py::class_<VarDesc>(m, "VarDesc", "") py::class_<VarDescBind>(m, "VarDesc", "")
.def(py::init<>()) .def("set_shape", &VarDescBind::SetShape)
.def("set_shape", VarDescBind::SetShape) .def("set_data_type", &VarDescBind::SetDataType)
.def("set_data_type", VarDescBind::SetDataType) .def("shape", &VarDescBind::Shape);
.def("shape", VarDescBind::Shape);
} }
void BindOpDesc(py::module &m) { void BindOpDesc(py::module &m) {
......
...@@ -33,7 +33,7 @@ class TestVarDesc(unittest.TestCase): ...@@ -33,7 +33,7 @@ class TestVarDesc(unittest.TestCase):
def test_shape(self): def test_shape(self):
program_desc = core.ProgramDesc.instance() program_desc = core.ProgramDesc.instance()
block = program_desc.root_block() block = program_desc.root_block()
var = block.new_var() var = block.new_var('my_var')
src_shape = [3, 2, 10, 8] src_shape = [3, 2, 10, 8]
var.set_shape(src_shape) var.set_shape(src_shape)
res_shape = var.shape() res_shape = var.shape()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册