diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py index ef8b20a48dc6071cc0e0eff2aa22ca5cb5ba8816..425c5f897a9c254bdae2aa1a6e91a4ce7a69874e 100644 --- a/demo/mnist/api_train.py +++ b/demo/mnist/api_train.py @@ -58,11 +58,25 @@ def main(): for _ in xrange(100): updater.startPass() + outArgs = api.Arguments.createArguments(0) train_data_generator = input_order_converter( read_from_mnist(train_file)) - for data_batch in generator_to_batch(train_data_generator, 128): + for batch_id, data_batch in enumerate( + generator_to_batch(train_data_generator, 256)): trainRole = updater.startBatch(len(data_batch)) + def update_callback(param): + updater.update(param) + + m.forwardBackward( + converter(data_batch), outArgs, trainRole, update_callback) + + cost_vec = outArgs.getSlotValue(0) + cost_vec = cost_vec.copyToNumpyMat() + cost = cost_vec.sum() / len(data_batch) + print 'Batch id', batch_id, 'with cost=', cost + updater.finishBatch(cost) + updater.finishPass() m.finish() diff --git a/paddle/api/PaddleAPI.h b/paddle/api/PaddleAPI.h index 165997ba3499f4ee5225aec70ef3dfd0cf1a6d3e..cc49e6a09d5dee41ff47606025fcc492559aa958 100644 --- a/paddle/api/PaddleAPI.h +++ b/paddle/api/PaddleAPI.h @@ -799,7 +799,7 @@ public: void finishPass(); - PassType startBatch(int64_t batchSize); + PassType startBatch(size_t batchSize); void finishBatch(float cost); diff --git a/paddle/api/ParameterUpdater.cpp b/paddle/api/ParameterUpdater.cpp index e5d07b81782bf5212a043bea8d35c60a8b7ae4fa..fba47620249dbc7543678b3e7e969a21ff32647a 100644 --- a/paddle/api/ParameterUpdater.cpp +++ b/paddle/api/ParameterUpdater.cpp @@ -36,8 +36,8 @@ void ParameterUpdater::startPass() { m->updater->startPass(); } void ParameterUpdater::finishPass() { m->updater->finishPass(); } -PassType ParameterUpdater::startBatch(int64_t batchSize) { - return m->updater->startBatch(batchSize); +PassType ParameterUpdater::startBatch(size_t batchSize) { + return m->updater->startBatch((int64_t)batchSize); } void ParameterUpdater::finishBatch(float cost) {