From 78038e65a96cd8197b37134ecda69d8ac395ab60 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Sat, 10 Mar 2018 15:26:26 +0800 Subject: [PATCH] update mobilenet.py --- fluid/image_classification/mobilenet.py | 32 +++++++++++++------------ 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/fluid/image_classification/mobilenet.py b/fluid/image_classification/mobilenet.py index 48d266c0..28656013 100644 --- a/fluid/image_classification/mobilenet.py +++ b/fluid/image_classification/mobilenet.py @@ -172,13 +172,13 @@ def train(learning_rate, batch_size, num_passes, model_save_dir='model'): momentum=0.9, regularization=fluid.regularizer.L2Decay(5 * 1e-5)) opts = optimizer.minimize(avg_cost) - accuracy = fluid.evaluator.Accuracy(input=out, label=label) + + b_size = fluid.layers.create_tensor(dtype='int64') + b_acc = fluid.layers.accuracy(input=out, label=label, total=b_size) inference_program = fluid.default_main_program().clone() with fluid.program_guard(inference_program): - test_accuracy = fluid.evaluator.Accuracy(input=out, label=label) - test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states - inference_program = fluid.io.get_inference_program(test_target) + inference_program = fluid.io.get_inference_program(b_acc) place = fluid.CUDAPlace(0) exe = fluid.Executor(place) @@ -190,24 +190,26 @@ def train(learning_rate, batch_size, num_passes, model_save_dir='model'): paddle.dataset.flowers.test(), batch_size=batch_size) feeder = fluid.DataFeeder(place=place, feed_list=[image, label]) + train_pass_acc = fluid.average.WeightedAverage() + test_pass_acc = fluid.average.WeightedAverage() for pass_id in range(num_passes): - accuracy.reset(exe) + train_pass_acc.reset() for batch_id, data in enumerate(train_reader()): - loss, acc = exe.run(fluid.default_main_program(), - feed=feeder.feed(data), - fetch_list=[avg_cost] + accuracy.metrics) + loss, acc, size = exe.run(fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[avg_cost, b_acc, b_size]) + train_pass_acc.add(value=acc, weight=size) print("Pass {0}, batch {1}, loss {2}, acc {3}".format( pass_id, batch_id, loss[0], acc[0])) - pass_acc = accuracy.eval(exe) - test_accuracy.reset(exe) + test_pass_acc.reset() for data in test_reader(): - loss, acc = exe.run(inference_program, - feed=feeder.feed(data), - fetch_list=[avg_cost] + test_accuracy.metrics) - test_pass_acc = test_accuracy.eval(exe) + loss, acc, size = exe.run(inference_program, + feed=feeder.feed(data), + fetch_list=[avg_cost, b_acc, b_size]) + test_pass_acc.add(value=acc, weight=size) print("End pass {0}, train_acc {1}, test_acc {2}".format( - pass_id, pass_acc, test_pass_acc)) + pass_id, train_pass_acc.eval(), test_pass_acc.eval())) if pass_id % 10 == 0: model_path = os.path.join(model_save_dir, str(pass_id)) print 'save models to %s' % (model_path) -- GitLab