未验证 提交 e5e810c3 编写于 作者: J Jeff Wang 提交者: GitHub

Merge pull request #542 from jetfuel/image_classification_new_api_markdown

Image classification new api markdown
......@@ -421,7 +421,7 @@ def load_image(file):
im = im / 255.0 * 2.0 - 1.0
return im
cur_dir = cur_dir = os.getcwd()
cur_dir = os.getcwd()
img = load_image(cur_dir + '/image/infer_3.png')
```
......
......@@ -463,7 +463,7 @@ def load_image(file):
im = im / 255.0 * 2.0 - 1.0
return im
cur_dir = cur_dir = os.getcwd()
cur_dir = os.getcwd()
img = load_image(cur_dir + '/image/infer_3.png')
```
......
此差异已折叠。
此差异已折叠。
......@@ -42,6 +42,10 @@ def train_network():
return [avg_cost, accuracy]
def optimizer_program():
return fluid.optimizer.Adam(learning_rate=0.001)
def train(use_cuda, train_program, params_dirname):
BATCH_SIZE = 128
EPOCH_NUM = 2
......@@ -56,7 +60,7 @@ def train(use_cuda, train_program, params_dirname):
def event_handler(event):
if isinstance(event, fluid.EndStepEvent):
if event.step % 100 == 0:
print("Pass %d, Batch %d, Cost %f, Acc %f" %
print("\nPass %d, Batch %d, Cost %f, Acc %f" %
(event.step, event.epoch, event.metrics[0],
event.metrics[1]))
else:
......@@ -67,15 +71,14 @@ def train(use_cuda, train_program, params_dirname):
avg_cost, accuracy = trainer.test(
reader=test_reader, feed_order=['pixel', 'label'])
print('Loss {0:2.2}, Acc {1:2.2}'.format(avg_cost, accuracy))
print('\nTest with Pass {0}, Loss {1:2.2}, Acc {2:2.2}'.format(
event.epoch, avg_cost, accuracy))
if params_dirname is not None:
trainer.save_params(params_dirname)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
trainer = fluid.Trainer(
train_func=train_program,
optimizer=fluid.optimizer.Adam(learning_rate=0.001),
place=place)
train_func=train_program, optimizer_func=optimizer_program, place=place)
trainer.train(
reader=train_reader,
......@@ -99,14 +102,10 @@ def infer(use_cuda, inference_program, params_dirname=None):
im = im.resize((32, 32), Image.ANTIALIAS)
im = np.array(im).astype(np.float32)
# The storage order of the loaded image is W(widht),
# The storage order of the loaded image is W(width),
# H(height), C(channel). PaddlePaddle requires
# the CHW order, so transpose them.
im = im.transpose((2, 0, 1)) # CHW
# In the training phase, the channel order of CIFAR
# image is B(Blue), G(green), R(Red). But PIL open
# image in RGB mode. It must swap the channel order.
im = im[(2, 1, 0), :, :] # BGR
im = im / 255.0
# Add one dimension to mimic the list format.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册