提交 78038e65 编写于 作者: F fengjiayi

update mobilenet.py

上级 df8060e7
......@@ -172,13 +172,13 @@ def train(learning_rate, batch_size, num_passes, model_save_dir='model'):
momentum=0.9,
regularization=fluid.regularizer.L2Decay(5 * 1e-5))
opts = optimizer.minimize(avg_cost)
accuracy = fluid.evaluator.Accuracy(input=out, label=label)
b_size = fluid.layers.create_tensor(dtype='int64')
b_acc = fluid.layers.accuracy(input=out, label=label, total=b_size)
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
test_accuracy = fluid.evaluator.Accuracy(input=out, label=label)
test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states
inference_program = fluid.io.get_inference_program(test_target)
inference_program = fluid.io.get_inference_program(b_acc)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
......@@ -190,24 +190,26 @@ def train(learning_rate, batch_size, num_passes, model_save_dir='model'):
paddle.dataset.flowers.test(), batch_size=batch_size)
feeder = fluid.DataFeeder(place=place, feed_list=[image, label])
train_pass_acc = fluid.average.WeightedAverage()
test_pass_acc = fluid.average.WeightedAverage()
for pass_id in range(num_passes):
accuracy.reset(exe)
train_pass_acc.reset()
for batch_id, data in enumerate(train_reader()):
loss, acc = exe.run(fluid.default_main_program(),
loss, acc, size = exe.run(fluid.default_main_program(),
feed=feeder.feed(data),
fetch_list=[avg_cost] + accuracy.metrics)
fetch_list=[avg_cost, b_acc, b_size])
train_pass_acc.add(value=acc, weight=size)
print("Pass {0}, batch {1}, loss {2}, acc {3}".format(
pass_id, batch_id, loss[0], acc[0]))
pass_acc = accuracy.eval(exe)
test_accuracy.reset(exe)
test_pass_acc.reset()
for data in test_reader():
loss, acc = exe.run(inference_program,
loss, acc, size = exe.run(inference_program,
feed=feeder.feed(data),
fetch_list=[avg_cost] + test_accuracy.metrics)
test_pass_acc = test_accuracy.eval(exe)
fetch_list=[avg_cost, b_acc, b_size])
test_pass_acc.add(value=acc, weight=size)
print("End pass {0}, train_acc {1}, test_acc {2}".format(
pass_id, pass_acc, test_pass_acc))
pass_id, train_pass_acc.eval(), test_pass_acc.eval()))
if pass_id % 10 == 0:
model_path = os.path.join(model_save_dir, str(pass_id))
print 'save models to %s' % (model_path)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册