From f6e1d959d2f54a8baa183d76c8134f27c60edcba Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 18 Oct 2017 18:22:31 -0700 Subject: [PATCH] Expose VarDesc::persistable to Python (#4911) --- paddle/framework/var_desc.h | 4 ++++ paddle/pybind/protobuf.cc | 23 +++++++++++++---------- python/paddle/v2/framework/framework.py | 20 +++++++++++++++++++- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h index 688a46f8398..af4c26ca0a7 100644 --- a/paddle/framework/var_desc.h +++ b/paddle/framework/var_desc.h @@ -79,6 +79,10 @@ class VarDescBind { void SetType(VarDesc::VarType type) { desc_.set_type(type); } + bool Persistable() const { return desc_.persistable(); } + + void SetPersistable(bool persistable) { desc_.set_persistable(persistable); } + private: const TensorDesc &tensor_desc() const; TensorDesc *mutable_tensor_desc(); diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index d9647717d2a..a4fb9b7c075 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -202,16 +202,19 @@ void BindVarDsec(py::module &m) { .def("set_lod_level", &VarDescBind::SetLoDLevel) .def("type", &VarDescBind::GetType) .def("set_type", &VarDescBind::SetType) - .def("serialize_to_string", [](VarDescBind &var_desc) -> py::bytes { - const VarDesc *desc = var_desc.Proto(); - PADDLE_ENFORCE(desc->IsInitialized(), - "VarDesc has not been initialized."); - std::string res; - PADDLE_ENFORCE( - desc->SerializeToString(&res), - "Serialize VarDesc Error. This could be a bug of Paddle."); - return res; - }); + .def("serialize_to_string", + [](VarDescBind &var_desc) -> py::bytes { + const VarDesc *desc = var_desc.Proto(); + PADDLE_ENFORCE(desc->IsInitialized(), + "VarDesc has not been initialized."); + std::string res; + PADDLE_ENFORCE( + desc->SerializeToString(&res), + "Serialize VarDesc Error. This could be a bug of Paddle."); + return res; + }) + .def("persistable", &VarDescBind::Persistable) + .def("set_persistable", &VarDescBind::SetPersistable); py::enum_(var_desc, "VarType", "") .value("LOD_TENSOR", VarDesc::LOD_TENSOR) diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py index 5a8ded46ea4..8c63ca9644f 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/framework/framework.py @@ -15,6 +15,7 @@ class Variable(object): shape=None, dtype=None, lod_level=None, + persistable=False, **kwargs): self.block = block @@ -70,6 +71,17 @@ class Variable(object): "lod_level is {2}. They are not " "matched".format(self.name, self.lod_level, lod_level)) + if persistable is not None: + if is_new_var: + self.desc.set_persistable(persistable) + else: + if persistable != self.persistable: + raise ValueError( + "Variable {0} has been created before." + "The previous persistable is {1}; the new " + "persistable is {2}. They are not matched".format( + self.name, self.persistable, persistable)) + self.block.vars[name] = self self.op = None @@ -80,6 +92,10 @@ class Variable(object): __repr__ = __str__ + @property + def persistable(self): + return self.desc.persistable() + @property def name(self): return self.desc.name() @@ -445,7 +461,9 @@ class Parameter(Variable): if each < 0: raise ValueError("Parameter shape should not be related with " "batch-size") - Variable.__init__(self, block, shape=shape, dtype=dtype, **kwargs) + + Variable.__init__( + self, block, persistable=True, shape=shape, dtype=dtype, **kwargs) self.trainable = kwargs.get('trainable', True) self.init_attr = kwargs.get('initialize_attr', { 'type': 'uniform_random', -- GitLab