From 20249e8e65aca17abaa9bbee9ab660e3573e21cf Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 21 Dec 2016 13:55:44 +0800 Subject: [PATCH] Try expose ParamUpdater::update --- demo/mnist/api_train.py | 3 +-- paddle/api/PaddleAPI.h | 6 ++++++ paddle/api/ParameterUpdater.cpp | 13 +++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index e508af7a0c5..ef8b20a48dc 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -45,7 +45,6 @@ def main(): config.model_config, api.CREATE_MODE_NORMAL, enable_types) assert isinstance(m, api.GradientMachine) init_parameter(network=m) - updater = api.ParameterUpdater.createLocalUpdater(opt_config) assert isinstance(updater, api.ParameterUpdater) updater.init(m) @@ -62,7 +61,7 @@ def main(): train_data_generator = input_order_converter( read_from_mnist(train_file)) for data_batch in generator_to_batch(train_data_generator, 128): - inArgs = converter(data_batch) + trainRole = updater.startBatch(len(data_batch)) updater.finishPass() diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index c074325091d..165997ba349 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -799,6 +799,12 @@ public: void finishPass(); + PassType startBatch(int64_t batchSize); + + void finishBatch(float cost); + + void update(Parameter* param); + private: ParameterUpdaterPrivate* m; }; diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index 4edec78b4a3..e5d07b81782 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -35,3 +35,16 @@ void ParameterUpdater::init(const GradientMachine &gm) { void ParameterUpdater::startPass() { m->updater->startPass(); } void ParameterUpdater::finishPass() { m->updater->finishPass(); } + +PassType ParameterUpdater::startBatch(int64_t batchSize) { + return m->updater->startBatch(batchSize); +} + +void ParameterUpdater::finishBatch(float cost) { + m->updater->finishBatch(cost); +} + +void ParameterUpdater::update(Parameter *param) { + auto paddleParam = param->m->getPtr(); + m->updater->update(paddleParam); +} -- GitLab