提交 08e99006 编写于 作者: F fengjiayi

Fix bugs

上级 57c95c79
......@@ -50,12 +50,12 @@ public:
VarDesc *Proto() { return &desc_; }
void SetShape(const vector<int64_t> &dims) {
void SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
}
void SetDataType(int type_id) {
desc_.mutable_lod_tensor()->set_data_type(const_cast<DataType>(type_id));
desc_.mutable_lod_tensor()->set_data_type(static_cast<DataType>(type_id));
}
std::vector<int64_t> Shape() {
......@@ -86,7 +86,8 @@ public:
int32_t Parent() const { return desc_->parent_idx(); }
VarDescBind *NewVar(const std::string &name) {
VarDescBind *NewVar(py::bytes name_bytes) {
std::string name = name_bytes;
need_update_ = true;
auto it = vars_.find(name);
PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name);
......@@ -224,16 +225,15 @@ void BindBlockDesc(py::module &m) {
&BlockDescBind::AppendOp,
py::return_value_policy::reference)
.def("new_var",
[](BlockDesc &self) { return self.add_vars(); },
&BlockDescBind::NewVar,
py::return_value_policy::reference);
}
void BindVarDsec(py::module &m) {
py::class_<VarDesc>(m, "VarDesc", "")
.def(py::init<>())
.def("set_shape", VarDescBind::SetShape)
.def("set_data_type", VarDescBind::SetDataType)
.def("shape", VarDescBind::Shape);
py::class_<VarDescBind>(m, "VarDesc", "")
.def("set_shape", &VarDescBind::SetShape)
.def("set_data_type", &VarDescBind::SetDataType)
.def("shape", &VarDescBind::Shape);
}
void BindOpDesc(py::module &m) {
......
......@@ -33,7 +33,7 @@ class TestVarDesc(unittest.TestCase):
def test_shape(self):
program_desc = core.ProgramDesc.instance()
block = program_desc.root_block()
var = block.new_var()
var = block.new_var('my_var')
src_shape = [3, 2, 10, 8]
var.set_shape(src_shape)
res_shape = var.shape()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册