diff --git a/paddle/api/Arguments.cpp b/paddle/api/Arguments.cpp index 0cafbd896e2d88aee4406bd0305878ce489bc18d..41beed38a87601cb57072c8966cd0fd2ea156524 100644 --- a/paddle/api/Arguments.cpp +++ b/paddle/api/Arguments.cpp @@ -137,6 +137,10 @@ void Arguments::setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError) { a.cpuSequenceDims = m->cast(vec->getSharedPtr()); } +float Arguments::sumCosts() const { + return paddle::Argument::sumCosts(m->outputs); +} + int64_t Arguments::getBatchSize(size_t idx) const throw(RangeError) { auto& a = m->getArg(idx); return a.getBatchSize(); diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 364d19f9414430709108824dce75a1007332d824..f5af8b0035b44d97832dd90ca2eeba079503715c 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -450,6 +450,8 @@ public: IVector* vec) throw(RangeError); void setSlotSequenceDim(size_t idx, IVector* vec) throw(RangeError); + float sumCosts() const; + private: static Arguments* createByPaddleArgumentVector(void* ptr); void* getInternalArgumentsPtr() const; @@ -546,6 +548,10 @@ public: ParameterConfig* getConfig(); void setValueUpdated(); + bool save(const std::string& filename) const; + + bool load(const std::string& filename) const; + size_t getSize() const; private: diff --git a/paddle/api/Parameter.cpp b/paddle/api/Parameter.cpp index ddc00d8d1af4c58d7e2233423bea916408bee92b..19f7a898d6b8d3d02c5654559dcb86728266731e 100644 --- a/paddle/api/Parameter.cpp +++ b/paddle/api/Parameter.cpp @@ -57,4 +57,12 @@ size_t Parameter::getID() const { return m->getPtr()->getID(); } void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); } +bool Parameter::save(const std::string& filename) const { + return m->getPtr()->save(filename); +} + +bool Parameter::load(const std::string& filename) const { + return m->getPtr()->load(filename); +} + size_t Parameter::getSize() const { return m->getPtr()->getSize(); } diff --git a/paddle/api/test/.gitignore b/paddle/api/test/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b7948824a1eab119140dd9bea20276c303fe4af1 --- /dev/null +++ b/paddle/api/test/.gitignore @@ -0,0 +1,2 @@ +*.w0 +*.wbias diff --git a/paddle/api/test/testArguments.py b/paddle/api/test/testArguments.py index 8cabecd242fb4eb98c0fe468687ef179245e4535..a04a805d7a64ef906c8388f1241b9ef823e4d9e0 100644 --- a/paddle/api/test/testArguments.py +++ b/paddle/api/test/testArguments.py @@ -22,6 +22,8 @@ class TestArguments(unittest.TestCase): args = swig_paddle.Arguments.createArguments(1) args.setSlotValue(0, m) + self.assertAlmostEqual(27.0, args.sumCosts()) + mat = args.getSlotValue(0) assert isinstance(mat, swig_paddle.Matrix) np_mat = mat.toNumpyMatInplace() diff --git a/paddle/api/test/testGradientMachine.py b/paddle/api/test/testGradientMachine.py index b81eafa9673ca34f1b7e06401098d55bdb1b35a5..4b705f66eccd267f326fe0662a17b33a09fda982 100644 --- a/paddle/api/test/testGradientMachine.py +++ b/paddle/api/test/testGradientMachine.py @@ -45,6 +45,7 @@ class TestGradientMachine(unittest.TestCase): assert isinstance(val, swig_paddle.Vector) arr = numpy.full((len(val), ), 0.1, dtype="float32") val.copyFromNumpyArray(arr) + self.assertTrue(param.save(param.getName())) param_config = param.getConfig().toProto() assert isinstance(param_config, paddle.proto.ParameterConfig_pb2.ParameterConfig) @@ -92,6 +93,9 @@ class TestGradientMachine(unittest.TestCase): self.assertTrue(self.isCalled) + for param in machine.getParameters(): + self.assertTrue(param.load(param.getName())) + def test_train_one_pass(self): conf_file_path = './testTrainConfig.py' trainer_config = swig_paddle.TrainerConfig.createFromTrainerConfigFile(