提交 05ab22c3 编写于 作者: Y Yu Yang

A simplest train file for mnist added.

上级 20249e8e
......@@ -58,11 +58,25 @@ def main():
for _ in xrange(100):
updater.startPass()
outArgs = api.Arguments.createArguments(0)
train_data_generator = input_order_converter(
read_from_mnist(train_file))
for data_batch in generator_to_batch(train_data_generator, 128):
for batch_id, data_batch in enumerate(
generator_to_batch(train_data_generator, 256)):
trainRole = updater.startBatch(len(data_batch))
def update_callback(param):
updater.update(param)
m.forwardBackward(
converter(data_batch), outArgs, trainRole, update_callback)
cost_vec = outArgs.getSlotValue(0)
cost_vec = cost_vec.copyToNumpyMat()
cost = cost_vec.sum() / len(data_batch)
print 'Batch id', batch_id, 'with cost=', cost
updater.finishBatch(cost)
updater.finishPass()
m.finish()
......
......@@ -799,7 +799,7 @@ public:
void finishPass();
PassType startBatch(int64_t batchSize);
PassType startBatch(size_t batchSize);
void finishBatch(float cost);
......
......@@ -36,8 +36,8 @@ void ParameterUpdater::startPass() { m->updater->startPass(); }
void ParameterUpdater::finishPass() { m->updater->finishPass(); }
PassType ParameterUpdater::startBatch(int64_t batchSize) {
return m->updater->startBatch(batchSize);
PassType ParameterUpdater::startBatch(size_t batchSize) {
return m->updater->startBatch((int64_t)batchSize);
}
void ParameterUpdater::finishBatch(float cost) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册