提交 05ab22c3 编写于 作者: Y Yu Yang

A simplest train file for mnist added.

上级 20249e8e
...@@ -58,11 +58,25 @@ def main(): ...@@ -58,11 +58,25 @@ def main():
for _ in xrange(100): for _ in xrange(100):
updater.startPass() updater.startPass()
outArgs = api.Arguments.createArguments(0)
train_data_generator = input_order_converter( train_data_generator = input_order_converter(
read_from_mnist(train_file)) read_from_mnist(train_file))
for data_batch in generator_to_batch(train_data_generator, 128): for batch_id, data_batch in enumerate(
generator_to_batch(train_data_generator, 256)):
trainRole = updater.startBatch(len(data_batch)) trainRole = updater.startBatch(len(data_batch))
def update_callback(param):
updater.update(param)
m.forwardBackward(
converter(data_batch), outArgs, trainRole, update_callback)
cost_vec = outArgs.getSlotValue(0)
cost_vec = cost_vec.copyToNumpyMat()
cost = cost_vec.sum() / len(data_batch)
print 'Batch id', batch_id, 'with cost=', cost
updater.finishBatch(cost)
updater.finishPass() updater.finishPass()
m.finish() m.finish()
......
...@@ -799,7 +799,7 @@ public: ...@@ -799,7 +799,7 @@ public:
void finishPass(); void finishPass();
PassType startBatch(int64_t batchSize); PassType startBatch(size_t batchSize);
void finishBatch(float cost); void finishBatch(float cost);
......
...@@ -36,8 +36,8 @@ void ParameterUpdater::startPass() { m->updater->startPass(); } ...@@ -36,8 +36,8 @@ void ParameterUpdater::startPass() { m->updater->startPass(); }
void ParameterUpdater::finishPass() { m->updater->finishPass(); } void ParameterUpdater::finishPass() { m->updater->finishPass(); }
PassType ParameterUpdater::startBatch(int64_t batchSize) { PassType ParameterUpdater::startBatch(size_t batchSize) {
return m->updater->startBatch(batchSize); return m->updater->startBatch((int64_t)batchSize);
} }
void ParameterUpdater::finishBatch(float cost) { void ParameterUpdater::finishBatch(float cost) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册