diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 155e3e3afe6d915ab934a588e5d8a882be58e2a0..bc1b22e187fc2ea7290cff62b90786c68d15a985 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -551,6 +551,10 @@ public: ParameterConfig* getConfig(); void setValueUpdated(); + bool save(const std::string& filename) const; + + bool load(const std::string& filename) const; + private: static Parameter* createFromRawPtr(void* ptr); static Parameter* createFromSharedPtr(void* ptr); diff --git a/paddle/api/Parameter.cpp b/paddle/api/Parameter.cpp index 4eed00a84a695f2c48ff93b33419ae2b3dd03768..9cfa2e35f5569bfafc80ceac7199a401c58c5991 100644 --- a/paddle/api/Parameter.cpp +++ b/paddle/api/Parameter.cpp @@ -70,3 +70,11 @@ ParameterConfig* Parameter::getConfig() { size_t Parameter::getID() const { return m->getPtr()->getID(); } void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); } + +bool Parameter::save(const std::string& filename) const { + return m->getPtr()->save(filename); +} + +bool Parameter::load(const std::string& filename) const { + return m->getPtr()->load(filename); +} diff --git a/paddle/api/test/.gitignore b/paddle/api/test/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..ef37ef416791c279e8696b1564b749ff6c6130a5 --- /dev/null +++ b/paddle/api/test/.gitignore @@ -0,0 +1,6 @@ +___fc_layer_0__.w0 +___fc_layer_0__.wbias +_hidden1.w0 +_hidden1.wbias +_hidden2.w0 +_hidden2.wbias diff --git a/paddle/api/test/testGradientMachine.py b/paddle/api/test/testGradientMachine.py index b81eafa9673ca34f1b7e06401098d55bdb1b35a5..4b705f66eccd267f326fe0662a17b33a09fda982 100644 --- a/paddle/api/test/testGradientMachine.py +++ b/paddle/api/test/testGradientMachine.py @@ -45,6 +45,7 @@ class TestGradientMachine(unittest.TestCase): assert isinstance(val, swig_paddle.Vector) arr = numpy.full((len(val), ), 0.1, dtype="float32") val.copyFromNumpyArray(arr) + self.assertTrue(param.save(param.getName())) param_config = param.getConfig().toProto() assert isinstance(param_config, paddle.proto.ParameterConfig_pb2.ParameterConfig) @@ -92,6 +93,9 @@ class TestGradientMachine(unittest.TestCase): self.assertTrue(self.isCalled) + for param in machine.getParameters(): + self.assertTrue(param.load(param.getName())) + def test_train_one_pass(self): conf_file_path = './testTrainConfig.py' trainer_config = swig_paddle.TrainerConfig.createFromTrainerConfigFile(