diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index e508af7a0c571c7ccb1f07bcd5267d45481d0e6b..ef8b20a48dc6071cc0e0eff2aa22ca5cb5ba8816 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -45,7 +45,6 @@ def main(): config.model_config, api.CREATE_MODE_NORMAL, enable_types) assert isinstance(m, api.GradientMachine) init_parameter(network=m) - updater = api.ParameterUpdater.createLocalUpdater(opt_config) assert isinstance(updater, api.ParameterUpdater) updater.init(m) @@ -62,7 +61,7 @@ def main(): train_data_generator = input_order_converter( read_from_mnist(train_file)) for data_batch in generator_to_batch(train_data_generator, 128): - inArgs = converter(data_batch) + trainRole = updater.startBatch(len(data_batch)) updater.finishPass() diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index c074325091dee95ab4a07e2f06e2abeb7a5c76bc..165997ba3499f4ee5225aec70ef3dfd0cf1a6d3e 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -799,6 +799,12 @@ public: void finishPass(); + PassType startBatch(int64_t batchSize); + + void finishBatch(float cost); + + void update(Parameter* param); + private: ParameterUpdaterPrivate* m; }; diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index 4edec78b4a3d42f226a155663a03f84b1840e03b..e5d07b81782bf5212a043bea8d35c60a8b7ae4fa 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -35,3 +35,16 @@ void ParameterUpdater::init(const GradientMachine &gm) { void ParameterUpdater::startPass() { m->updater->startPass(); } void ParameterUpdater::finishPass() { m->updater->finishPass(); } + +PassType ParameterUpdater::startBatch(int64_t batchSize) { + return m->updater->startBatch(batchSize); +} + +void ParameterUpdater::finishBatch(float cost) { + m->updater->finishBatch(cost); +} + +void ParameterUpdater::update(Parameter *param) { + auto paddleParam = param->m->getPtr(); + m->updater->update(paddleParam); +}