未验证 提交 1ae49ef4 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #1090 from luotao1/fix_image_classification_error

fix image_classification error
...@@ -160,5 +160,5 @@ def val(file_list=TEST_LIST): ...@@ -160,5 +160,5 @@ def val(file_list=TEST_LIST):
return _reader_creator(file_list, 'val', shuffle=False) return _reader_creator(file_list, 'val', shuffle=False)
def test(file_list): def test(file_list=TEST_LIST):
return _reader_creator(file_list, 'test', shuffle=False) return _reader_creator(file_list, 'test', shuffle=False)
...@@ -157,7 +157,8 @@ def train(args): ...@@ -157,7 +157,8 @@ def train(args):
test_reader = paddle.batch(reader.val(), batch_size=test_batch_size) test_reader = paddle.batch(reader.val(), batch_size=test_batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label]) feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name) train_exe = fluid.ParallelExecutor(
use_cuda=True if args.use_gpu else False, loss_name=avg_cost.name)
fetch_list = [avg_cost.name, acc_top1.name, acc_top5.name] fetch_list = [avg_cost.name, acc_top1.name, acc_top5.name]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册