From f8029403a06bffb87134e8e7f724d5ee45dfd89b Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 5 Mar 2018 13:24:06 +0800 Subject: [PATCH] remove Evaluator.Accuracy --- benchmark/cluster/vgg16/vgg16_fluid.py | 35 +++++++++-------- python/paddle/fluid/__init__.py | 1 + python/paddle/{v2 => }/fluid/average.py | 0 python/paddle/fluid/evaluator.py | 38 ------------------- python/paddle/{v2 => }/fluid/layers/metric.py | 0 .../test_memopt_image_classification_train.py | 17 ++++++--- .../fluid/tests/unittests/test_profiler.py | 12 ++++-- 7 files changed, 40 insertions(+), 63 deletions(-) rename python/paddle/{v2 => }/fluid/average.py (100%) rename python/paddle/{v2 => }/fluid/layers/metric.py (100%) diff --git a/benchmark/cluster/vgg16/vgg16_fluid.py b/benchmark/cluster/vgg16/vgg16_fluid.py index 7323241f4d3..80eee112ddf 100644 --- a/benchmark/cluster/vgg16/vgg16_fluid.py +++ b/benchmark/cluster/vgg16/vgg16_fluid.py @@ -1,11 +1,11 @@ # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -138,13 +138,14 @@ def main(): avg_cost = fluid.layers.mean(x=cost) # Evaluator - accuracy = fluid.evaluator.Accuracy(input=predict, label=label) + batch_size = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size) # inference program inference_program = fluid.default_main_program().clone() with fluid.program_guard(inference_program): - test_target = accuracy.metrics + accuracy.states - inference_program = fluid.io.get_inference_program(test_target) + inference_program = fluid.io.get_inference_program(batch_acc) # Optimization optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate) @@ -157,27 +158,30 @@ def main(): # test def test(exe): - accuracy.reset(exe) + test_pass_acc = fluid.average.WeightedAverage() for batch_id, data in enumerate(test_reader()): img_data = np.array(map(lambda x: x[0].reshape(data_shape), data)).astype("float32") y_data = np.array(map(lambda x: x[1], data)).astype("int64") y_data = y_data.reshape([-1, 1]) - exe.run(inference_program, - feed={"pixel": img_data, - "label": y_data}) + outs = exe.run(inference_program, + feed={"pixel": img_data, + "label": y_data}, + fetch_list=[batch_acc, batch_size]) + test_pass_acc.add(value=np.array(outs[0]), weight=np.array(outs[1])) - return accuracy.eval(exe) + return test_pass_acc.eval() def train_loop(exe, trainer_prog): iters = 0 ts = time.time() + train_pass_acc = fluid.average.WeightedAverage() for pass_id in range(args.num_passes): # train start_time = time.time() num_samples = 0 - accuracy.reset(exe) + train_pass_acc.reset() with profiler.profiler("CPU", 'total') as prof: for batch_id, data in enumerate(train_reader()): ts = time.time() @@ -187,13 +191,14 @@ def main(): y_data = np.array(map(lambda x: x[1], data)).astype("int64") y_data = y_data.reshape([-1, 1]) - loss, acc = exe.run( + loss, acc, b_size = exe.run( trainer_prog, feed={"pixel": img_data, "label": y_data}, - fetch_list=[avg_cost] + accuracy.metrics) + fetch_list=[avg_cost, batch_acc, batch_size]) iters += 1 num_samples += len(data) + train_pass_acc.add(value=acc, weight=b_size) print( "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s" % (pass_id, iters, loss, acc, @@ -201,7 +206,7 @@ def main(): ) # The accuracy is the accumulation of batches, but not the current batch. pass_elapsed = time.time() - start_time - pass_train_acc = accuracy.eval(exe) + pass_train_acc = train_pass_acc.eval() pass_test_acc = test(exe) print( "Pass = %d, Training performance = %f imgs/s, Train accuracy = %f, Test accuracy = %f\n" diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 39d13d3ab5f..2afb3f2f649 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -29,6 +29,7 @@ import optimizer import learning_rate_decay import backward import regularizer +import average from param_attr import ParamAttr, WeightNormParamAttr from data_feeder import DataFeeder from core import LoDTensor, CPUPlace, CUDAPlace diff --git a/python/paddle/v2/fluid/average.py b/python/paddle/fluid/average.py similarity index 100% rename from python/paddle/v2/fluid/average.py rename to python/paddle/fluid/average.py diff --git a/python/paddle/fluid/evaluator.py b/python/paddle/fluid/evaluator.py index 8cc49053337..d8caecb3fc4 100644 --- a/python/paddle/fluid/evaluator.py +++ b/python/paddle/fluid/evaluator.py @@ -105,44 +105,6 @@ class Evaluator(object): return state -class Accuracy(Evaluator): - """ - Average Accuracy for multiple mini-batches. - """ - - def __init__(self, input, label, k=1, **kwargs): - super(Accuracy, self).__init__("accuracy", **kwargs) - main_program = self.helper.main_program - if main_program.current_block().idx != 0: - raise ValueError("You can only invoke Evaluator in root block") - - self.total = self.create_state(dtype='int64', shape=[1], suffix='total') - self.correct = self.create_state( - dtype='int64', shape=[1], suffix='correct') - total = self.helper.create_tmp_variable(dtype='int') - correct = self.helper.create_tmp_variable(dtype='int') - acc = layers.accuracy( - input=input, label=label, k=k, total=total, correct=correct) - total = layers.cast(x=total, dtype='int64') - correct = layers.cast(x=correct, dtype='int64') - layers.sums(input=[self.total, total], out=self.total) - layers.sums(input=[self.correct, correct], out=self.correct) - - self.metrics.append(acc) - - def eval(self, executor, eval_program=None): - if eval_program is None: - eval_program = Program() - block = eval_program.current_block() - with program_guard(main_program=eval_program): - total = _clone_var_(block, self.total) - correct = _clone_var_(block, self.correct) - total = layers.cast(total, dtype='float32') - correct = layers.cast(correct, dtype='float32') - out = layers.elementwise_div(x=correct, y=total) - return np.array(executor.run(eval_program, fetch_list=[out])[0]) - - class ChunkEvaluator(Evaluator): """ Accumulate counter numbers output by chunk_eval from mini-batches and diff --git a/python/paddle/v2/fluid/layers/metric.py b/python/paddle/fluid/layers/metric.py similarity index 100% rename from python/paddle/v2/fluid/layers/metric.py rename to python/paddle/fluid/layers/metric.py diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py index 57202cea1aa..a3e0893f2df 100644 --- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py +++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py @@ -122,7 +122,8 @@ avg_cost = fluid.layers.mean(cost) optimizer = fluid.optimizer.Adam(learning_rate=0.001) opts = optimizer.minimize(avg_cost) -accuracy = fluid.evaluator.Accuracy(input=predict, label=label) +batch_size = fluid.layers.create_tensor(dtype='int64') +batch_acc = fluid.layers.accuracy(input=predict, label=label, total=batch_size) fluid.memory_optimize(fluid.default_main_program()) @@ -144,13 +145,17 @@ feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) exe.run(fluid.default_startup_program()) i = 0 + +accuracy = fluid.average.WeightedAverage() for pass_id in range(PASS_NUM): - accuracy.reset(exe) + accuracy.reset() for data in train_reader(): - loss, acc = exe.run(fluid.default_main_program(), - feed=feeder.feed(data), - fetch_list=[avg_cost] + accuracy.metrics) - pass_acc = accuracy.eval(exe) + loss, acc, weight = exe.run( + fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[avg_cost, batch_acc, batch_size]) + accuracy.add(value=acc, weight=weight) + pass_acc = accuracy.eval() print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str( pass_acc)) # this model is slow, so if we can train two mini batch, we think it works properly. diff --git a/python/paddle/fluid/tests/unittests/test_profiler.py b/python/paddle/fluid/tests/unittests/test_profiler.py index f6f581ff7d6..b2ce655cad0 100644 --- a/python/paddle/fluid/tests/unittests/test_profiler.py +++ b/python/paddle/fluid/tests/unittests/test_profiler.py @@ -37,7 +37,9 @@ class TestProfiler(unittest.TestCase): label = fluid.layers.data(name='y', shape=[1], dtype='int64') cost = fluid.layers.cross_entropy(input=predict, label=label) avg_cost = fluid.layers.mean(cost) - accuracy = fluid.evaluator.Accuracy(input=predict, label=label) + batch_size = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size) optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9) opts = optimizer.minimize(avg_cost, startup_program=startup_program) @@ -46,7 +48,7 @@ class TestProfiler(unittest.TestCase): exe = fluid.Executor(place) exe.run(startup_program) - accuracy.reset(exe) + pass_acc_calculator = fluid.average.WeightedAverage() with profiler.profiler(state, 'total') as prof: for iter in range(10): if iter == 2: @@ -57,9 +59,11 @@ class TestProfiler(unittest.TestCase): outs = exe.run(main_program, feed={'x': x, 'y': y}, - fetch_list=[avg_cost] + accuracy.metrics) + fetch_list=[avg_cost, batch_acc, batch_size]) acc = np.array(outs[1]) - pass_acc = accuracy.eval(exe) + b_size = np.array(outs[2]) + pass_acc_calculator.add(value=acc, weight=b_size) + pass_acc = pass_acc_calculator.eval() def test_cpu_profiler(self): self.net_profiler('CPU') -- GitLab