提交 43d92fa8 编写于 作者: Y Yu Yang

Make api_train_v2 runnable

上级 7293c821
...@@ -20,7 +20,7 @@ def main(): ...@@ -20,7 +20,7 @@ def main():
adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01) adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
trainer = paddle.trainer.SGD(topology=cost, trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters, parameters=parameters,
update_equation=adam_optimizer) update_equation=adam_optimizer)
...@@ -28,7 +28,7 @@ def main(): ...@@ -28,7 +28,7 @@ def main():
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 1000 == 0: if event.batch_id % 1000 == 0:
result = trainer.test(reader=paddle.reader.batched( result = trainer.test(reader=paddle.reader.batched(
paddle.dataset.mnist.test_creator(), batch_size=256)) paddle.dataset.mnist.test(), batch_size=256))
print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % ( print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics, event.pass_id, event.batch_id, event.cost, event.metrics,
......
...@@ -9,9 +9,9 @@ __all__ = ['train', 'test'] ...@@ -9,9 +9,9 @@ __all__ = ['train', 'test']
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/' URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz' TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6' TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz' TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
TEST_LABEL_MD5 = '4e9511fe019b2189026bd0421ba7b688' TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz' TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873' TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz' TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
......
...@@ -61,7 +61,7 @@ class SGD(ITrainer): ...@@ -61,7 +61,7 @@ class SGD(ITrainer):
self.__topology__ = topology self.__topology__ = topology
self.__parameters__ = parameters self.__parameters__ = parameters
self.__topology_in_proto__ = topology.proto() self.__topology_in_proto__ = topology.proto()
self.__data_types__ = topology.data_layers() self.__data_types__ = topology.data_type()
gm = api.GradientMachine.createFromConfigProto( gm = api.GradientMachine.createFromConfigProto(
self.__topology_in_proto__, api.CREATE_MODE_NORMAL, self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
self.__optimizer__.enable_types()) self.__optimizer__.enable_types())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册