From 05ab22c332e615f3c81f4d4b2c9b47f71229c71c Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 21 Dec 2016 14:22:31 +0800 Subject: [PATCH] A simplest train file for mnist added. --- demo/mnist/api_train.py | 16 +++++++++++++++- paddle/api/PaddleAPI.h | 2 +- paddle/api/ParameterUpdater.cpp | 4 ++-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index ef8b20a48dc..425c5f897a9 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -58,11 +58,25 @@ def main(): for _ in xrange(100): updater.startPass() + outArgs = api.Arguments.createArguments(0) train_data_generator = input_order_converter( read_from_mnist(train_file)) - for data_batch in generator_to_batch(train_data_generator, 128): + for batch_id, data_batch in enumerate( + generator_to_batch(train_data_generator, 256)): trainRole = updater.startBatch(len(data_batch)) + def update_callback(param): + updater.update(param) + + m.forwardBackward( + converter(data_batch), outArgs, trainRole, update_callback) + + cost_vec = outArgs.getSlotValue(0) + cost_vec = cost_vec.copyToNumpyMat() + cost = cost_vec.sum() / len(data_batch) + print 'Batch id', batch_id, 'with cost=', cost + updater.finishBatch(cost) + updater.finishPass() m.finish() diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 165997ba349..cc49e6a09d5 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -799,7 +799,7 @@ public: void finishPass(); - PassType startBatch(int64_t batchSize); + PassType startBatch(size_t batchSize); void finishBatch(float cost); diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index e5d07b81782..fba47620249 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -36,8 +36,8 @@ void ParameterUpdater::startPass() { m->updater->startPass(); } void ParameterUpdater::finishPass() { m->updater->finishPass(); } -PassType ParameterUpdater::startBatch(int64_t batchSize) { - return m->updater->startBatch(batchSize); +PassType ParameterUpdater::startBatch(size_t batchSize) { + return m->updater->startBatch((int64_t)batchSize); } void ParameterUpdater::finishBatch(float cost) { -- GitLab