提交 43d92fa8 编写于 作者: Y Yu Yang

Make api_train_v2 runnable

上级 7293c821
......@@ -20,7 +20,7 @@ def main():
adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
trainer = paddle.trainer.SGD(topology=cost,
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=adam_optimizer)
......@@ -28,7 +28,7 @@ def main():
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 1000 == 0:
result = trainer.test(reader=paddle.reader.batched(
paddle.dataset.mnist.test_creator(), batch_size=256))
paddle.dataset.mnist.test(), batch_size=256))
print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics,
......
......@@ -9,9 +9,9 @@ __all__ = ['train', 'test']
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6'
TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
TEST_LABEL_MD5 = '4e9511fe019b2189026bd0421ba7b688'
TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
......
......@@ -61,7 +61,7 @@ class SGD(ITrainer):
self.__topology__ = topology
self.__parameters__ = parameters
self.__topology_in_proto__ = topology.proto()
self.__data_types__ = topology.data_layers()
self.__data_types__ = topology.data_type()
gm = api.GradientMachine.createFromConfigProto(
self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
self.__optimizer__.enable_types())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册