From 8b833d5a8ada43ba8b049665d5c6161eeb0c5d65 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 26 Dec 2016 16:18:22 +0800 Subject: [PATCH] Add load/save method for Parameter --- paddle/api/PaddleAPI.h | 4 ++++ paddle/api/Parameter.cpp | 8 ++++++++ paddle/api/test/.gitignore | 6 ++++++ paddle/api/test/testGradientMachine.py | 4 ++++ 4 files changed, 22 insertions(+) create mode 100644 paddle/api/test/.gitignore diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 155e3e3afe..bc1b22e187 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -551,6 +551,10 @@ public: ParameterConfig* getConfig(); void setValueUpdated(); + bool save(const std::string& filename) const; + + bool load(const std::string& filename) const; + private: static Parameter* createFromRawPtr(void* ptr); static Parameter* createFromSharedPtr(void* ptr); diff --git a/paddle/api/Parameter.cpp b/paddle/api/Parameter.cpp index 4eed00a84a..9cfa2e35f5 100644 --- a/paddle/api/Parameter.cpp +++ b/paddle/api/Parameter.cpp @@ -70,3 +70,11 @@ ParameterConfig* Parameter::getConfig() { size_t Parameter::getID() const { return m->getPtr()->getID(); } void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); } + +bool Parameter::save(const std::string& filename) const { + return m->getPtr()->save(filename); +} + +bool Parameter::load(const std::string& filename) const { + return m->getPtr()->load(filename); +} diff --git a/paddle/api/test/.gitignore b/paddle/api/test/.gitignore new file mode 100644 index 0000000000..ef37ef4167 --- /dev/null +++ b/paddle/api/test/.gitignore @@ -0,0 +1,6 @@ +___fc_layer_0__.w0 +___fc_layer_0__.wbias +_hidden1.w0 +_hidden1.wbias +_hidden2.w0 +_hidden2.wbias diff --git a/paddle/api/test/testGradientMachine.py b/paddle/api/test/testGradientMachine.py index b81eafa967..4b705f66ec 100644 --- a/paddle/api/test/testGradientMachine.py +++ b/paddle/api/test/testGradientMachine.py @@ -45,6 +45,7 @@ class TestGradientMachine(unittest.TestCase): assert isinstance(val, swig_paddle.Vector) arr = numpy.full((len(val), ), 0.1, dtype="float32") val.copyFromNumpyArray(arr) + self.assertTrue(param.save(param.getName())) param_config = param.getConfig().toProto() assert isinstance(param_config, paddle.proto.ParameterConfig_pb2.ParameterConfig) @@ -92,6 +93,9 @@ class TestGradientMachine(unittest.TestCase): self.assertTrue(self.isCalled) + for param in machine.getParameters(): + self.assertTrue(param.load(param.getName())) + def test_train_one_pass(self): conf_file_path = './testTrainConfig.py' trainer_config = swig_paddle.TrainerConfig.createFromTrainerConfigFile( -- GitLab