提交 35e770f6 编写于 作者: G gongweibao

modify to paddle.batch

上级 785f9d44
...@@ -166,7 +166,7 @@ def event_handler(event): ...@@ -166,7 +166,7 @@ def event_handler(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
result = trainer.test( result = trainer.test(
reader=paddle.reader.batched( reader=paddle.batch(
uci_housing.test(), batch_size=2), uci_housing.test(), batch_size=2),
reader_dict=reader_dict) reader_dict=reader_dict)
print "Test %d, Cost %f" % (event.pass_id, result.cost) print "Test %d, Cost %f" % (event.pass_id, result.cost)
...@@ -176,7 +176,7 @@ def event_handler(event): ...@@ -176,7 +176,7 @@ def event_handler(event):
```python ```python
# training # training
trainer.train( trainer.train(
reader=paddle.reader.batched( reader=paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
uci_housing.train(), buf_size=500), uci_housing.train(), buf_size=500),
batch_size=2), batch_size=2),
......
...@@ -8,9 +8,7 @@ def main(): ...@@ -8,9 +8,7 @@ def main():
# network config # network config
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13)) x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
y_predict = paddle.layer.fc(input=x, y_predict = paddle.layer.fc(input=x, size=1, act=paddle.activation.Linear())
size=1,
act=paddle.activation.Linear())
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1)) y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
cost = paddle.layer.regression_cost(input=y_predict, label=y) cost = paddle.layer.regression_cost(input=y_predict, label=y)
...@@ -35,14 +33,14 @@ def main(): ...@@ -35,14 +33,14 @@ def main():
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
result = trainer.test( result = trainer.test(
reader=paddle.reader.batched( reader=paddle.batch(
uci_housing.test(), batch_size=2), uci_housing.test(), batch_size=2),
reader_dict=reader_dict) reader_dict=reader_dict)
print "Test %d, Cost %f" % (event.pass_id, result.cost) print "Test %d, Cost %f" % (event.pass_id, result.cost)
# training # training
trainer.train( trainer.train(
reader=paddle.reader.batched( reader=paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
uci_housing.train(), buf_size=500), uci_housing.train(), buf_size=500),
batch_size=2), batch_size=2),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册