diff --git a/benchmark/cluster/vgg16/vgg16_fluid.py b/benchmark/cluster/vgg16/vgg16_fluid.py index 53e71998f1517a8f2cfa4509426231fb0f0177e8..786f224608f7d41c438411de0e09fedbcf2264b8 100644 --- a/benchmark/cluster/vgg16/vgg16_fluid.py +++ b/benchmark/cluster/vgg16/vgg16_fluid.py @@ -1,11 +1,11 @@ # Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -138,13 +138,14 @@ def main(): avg_cost = fluid.layers.mean(x=cost) # Evaluator - accuracy = fluid.evaluator.Accuracy(input=predict, label=label) + batch_size = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size) # inference program inference_program = fluid.default_main_program().clone() with fluid.program_guard(inference_program): - test_target = accuracy.metrics + accuracy.states - inference_program = fluid.io.get_inference_program(test_target) + inference_program = fluid.io.get_inference_program(batch_acc) # Optimization optimizer = fluid.optimizer.Adam(learning_rate=args.learning_rate) @@ -157,27 +158,30 @@ def main(): # test def test(exe): - accuracy.reset(exe) + test_pass_acc = fluid.average.WeightedAverage() for batch_id, data in enumerate(test_reader()): img_data = np.array(map(lambda x: x[0].reshape(data_shape), data)).astype("float32") y_data = np.array(map(lambda x: x[1], data)).astype("int64") y_data = y_data.reshape([-1, 1]) - exe.run(inference_program, - feed={"pixel": img_data, - "label": y_data}) + outs = exe.run(inference_program, + feed={"pixel": img_data, + "label": y_data}, + fetch_list=[batch_acc, batch_size]) + test_pass_acc.add(value=np.array(outs[0]), weight=np.array(outs[1])) - return accuracy.eval(exe) + return test_pass_acc.eval() def train_loop(exe, trainer_prog): iters = 0 ts = time.time() + train_pass_acc = fluid.average.WeightedAverage() for pass_id in range(args.num_passes): # train start_time = time.time() num_samples = 0 - accuracy.reset(exe) + train_pass_acc.reset() with profiler.profiler("CPU", 'total') as prof: for batch_id, data in enumerate(train_reader()): ts = time.time() @@ -187,13 +191,14 @@ def main(): y_data = np.array(map(lambda x: x[1], data)).astype("int64") y_data = y_data.reshape([-1, 1]) - loss, acc = exe.run( + loss, acc, b_size = exe.run( trainer_prog, feed={"pixel": img_data, "label": y_data}, - fetch_list=[avg_cost] + accuracy.metrics) + fetch_list=[avg_cost, batch_acc, batch_size]) iters += 1 num_samples += len(data) + train_pass_acc.add(value=acc, weight=b_size) print( "Pass = %d, Iters = %d, Loss = %f, Accuracy = %f, Speed = %.2f img/s" % (pass_id, iters, loss, acc, @@ -201,7 +206,7 @@ def main(): ) # The accuracy is the accumulation of batches, but not the current batch. pass_elapsed = time.time() - start_time - pass_train_acc = accuracy.eval(exe) + pass_train_acc = train_pass_acc.eval() pass_test_acc = test(exe) print( "Pass = %d, Training performance = %f imgs/s, Train accuracy = %f, Test accuracy = %f\n" diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index 3f407d05768f707507e4a1339a64f3d7ae4506a9..0df3fd0343dbdaee88192a402d990fdfc2235811 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -28,6 +28,7 @@ import nets import optimizer import backward import regularizer +import average from param_attr import ParamAttr, WeightNormParamAttr from data_feeder import DataFeeder from core import LoDTensor, CPUPlace, CUDAPlace diff --git a/python/paddle/fluid/average.py b/python/paddle/fluid/average.py new file mode 100644 index 0000000000000000000000000000000000000000..ded6eb085968343fcdc9f6e4b8353c08408df426 --- /dev/null +++ b/python/paddle/fluid/average.py @@ -0,0 +1,61 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +""" + Class of all kinds of Average. + + All Averages are accomplished via Python totally. + They do not change Paddle's Program, nor do anything to + modify NN model's configuration. They are completely + wrappers of Python functions. +""" + + +def _is_number_(var): + return isinstance(var, int) or isinstance(var, float) or (isinstance( + var, np.ndarray) and var.shape == (1, )) + + +def _is_number_or_matrix_(var): + return _is_number_(var) or isinstance(var, np.ndarray) + + +class WeightedAverage(object): + def __init__(self): + self.reset() + + def reset(self): + self.numerator = None + self.denominator = None + + def add(self, value, weight): + if not _is_number_or_matrix_(value): + raise ValueError( + "The 'value' must be a number(int, float) or a numpy ndarray.") + if not _is_number_(weight): + raise ValueError("The 'weight' must be a number(int, float).") + + if self.numerator is None or self.denominator is None: + self.numerator = value * weight + self.denominator = weight + else: + self.numerator += value * weight + self.denominator += weight + + def eval(self): + if self.numerator is None or self.denominator is None: + raise ValueError( + "There is no data to be averaged in WeightedAverage.") + return self.numerator / self.denominator diff --git a/python/paddle/fluid/evaluator.py b/python/paddle/fluid/evaluator.py index 18b1cdce8ad88d1423d69772773fcfa78b11c592..19e5b61b0b32aba3fe1e7805704a3740e3854fc8 100644 --- a/python/paddle/fluid/evaluator.py +++ b/python/paddle/fluid/evaluator.py @@ -108,44 +108,6 @@ class Evaluator(object): return state -class Accuracy(Evaluator): - """ - Average Accuracy for multiple mini-batches. - """ - - def __init__(self, input, label, k=1, **kwargs): - super(Accuracy, self).__init__("accuracy", **kwargs) - main_program = self.helper.main_program - if main_program.current_block().idx != 0: - raise ValueError("You can only invoke Evaluator in root block") - - self.total = self.create_state(dtype='int64', shape=[1], suffix='total') - self.correct = self.create_state( - dtype='int64', shape=[1], suffix='correct') - total = self.helper.create_tmp_variable(dtype='int') - correct = self.helper.create_tmp_variable(dtype='int') - acc = layers.accuracy( - input=input, label=label, k=k, total=total, correct=correct) - total = layers.cast(x=total, dtype='int64') - correct = layers.cast(x=correct, dtype='int64') - layers.sums(input=[self.total, total], out=self.total) - layers.sums(input=[self.correct, correct], out=self.correct) - - self.metrics.append(acc) - - def eval(self, executor, eval_program=None): - if eval_program is None: - eval_program = Program() - block = eval_program.current_block() - with program_guard(main_program=eval_program): - total = _clone_var_(block, self.total) - correct = _clone_var_(block, self.correct) - total = layers.cast(total, dtype='float32') - correct = layers.cast(correct, dtype='float32') - out = layers.elementwise_div(x=correct, y=total) - return np.array(executor.run(eval_program, fetch_list=[out])[0]) - - class ChunkEvaluator(Evaluator): """ Accumulate counter numbers output by chunk_eval from mini-batches and diff --git a/python/paddle/fluid/layers/__init__.py b/python/paddle/fluid/layers/__init__.py index 14d33582f41a33da49b1e5176b2094a6a81b3dac..a568f61dcb2da976baa7847ae26281a34d6f88dd 100644 --- a/python/paddle/fluid/layers/__init__.py +++ b/python/paddle/fluid/layers/__init__.py @@ -28,6 +28,8 @@ import math_op_patch from math_op_patch import * import detection from detection import * +import metric +from metric import * from learning_rate_scheduler import * __all__ = [] @@ -39,4 +41,5 @@ __all__ += control_flow.__all__ __all__ += ops.__all__ __all__ += device.__all__ __all__ += detection.__all__ +__all__ += metric.__all__ __all__ += learning_rate_scheduler.__all__ diff --git a/python/paddle/fluid/layers/metric.py b/python/paddle/fluid/layers/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..3d9157ad4ef9381b70b4007c5bdca91f1482b427 --- /dev/null +++ b/python/paddle/fluid/layers/metric.py @@ -0,0 +1,57 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +All layers just related to metric. +""" + +from ..layer_helper import LayerHelper +from ..initializer import Normal, Constant +from ..framework import Variable +from ..param_attr import ParamAttr + +__all__ = ['accuracy'] + + +def accuracy(input, label, k=1, correct=None, total=None): + """ + This function computes the accuracy using the input and label. + The output is the top_k inputs and their indices. + """ + helper = LayerHelper("accuracy", **locals()) + topk_out = helper.create_tmp_variable(dtype=input.dtype) + topk_indices = helper.create_tmp_variable(dtype="int64") + helper.append_op( + type="top_k", + inputs={"X": [input]}, + outputs={"Out": [topk_out], + "Indices": [topk_indices]}, + attrs={"k": k}) + acc_out = helper.create_tmp_variable(dtype="float32") + if correct is None: + correct = helper.create_tmp_variable(dtype="int64") + if total is None: + total = helper.create_tmp_variable(dtype="int64") + helper.append_op( + type="accuracy", + inputs={ + "Out": [topk_out], + "Indices": [topk_indices], + "Label": [label] + }, + outputs={ + "Accuracy": [acc_out], + "Correct": [correct], + "Total": [total], + }) + return acc_out diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index ffa477ba9b88126d8ff0ed404e64830b087314e9..a0842c57ee5f807c046b36f51f8236897ceb641b 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -35,7 +35,6 @@ __all__ = [ 'cos_sim', 'cross_entropy', 'square_error_cost', - 'accuracy', 'chunk_eval', 'sequence_conv', 'conv2d', @@ -1022,40 +1021,6 @@ def square_error_cost(input, label): return square_out -def accuracy(input, label, k=1, correct=None, total=None): - """ - This function computes the accuracy using the input and label. - The output is the top_k inputs and their indices. - """ - helper = LayerHelper("accuracy", **locals()) - topk_out = helper.create_tmp_variable(dtype=input.dtype) - topk_indices = helper.create_tmp_variable(dtype="int64") - helper.append_op( - type="top_k", - inputs={"X": [input]}, - outputs={"Out": [topk_out], - "Indices": [topk_indices]}, - attrs={"k": k}) - acc_out = helper.create_tmp_variable(dtype="float32") - if correct is None: - correct = helper.create_tmp_variable(dtype="int64") - if total is None: - total = helper.create_tmp_variable(dtype="int64") - helper.append_op( - type="accuracy", - inputs={ - "Out": [topk_out], - "Indices": [topk_indices], - "Label": [label] - }, - outputs={ - "Accuracy": [acc_out], - "Correct": [correct], - "Total": [total], - }) - return acc_out - - def chunk_eval(input, label, chunk_scheme, @@ -3182,7 +3147,7 @@ def smooth_l1(x, y, inside_weight=None, outside_weight=None, sigma=None): data = fluid.layers.data(name='data', shape=[128], dtype='float32') label = fluid.layers.data(name='label', shape=[100], dtype='int64') fc = fluid.layers.fc(input=data, size=100) - out = fluid.layers.smooth_l1(logits=fc, label=label) + out = fluid.layers.smooth_l1(x=fc, y=label) """ helper = LayerHelper('smooth_l1_loss', **locals()) diff = helper.create_tmp_variable(dtype=x.dtype) diff --git a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py index f900da815b204cd7472b93607b816054999d7d69..80ff11f8d78b0a22fc6aefd722c9e6a2c23fbd5c 100644 --- a/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py +++ b/python/paddle/fluid/tests/book_memory_optimization/test_memopt_image_classification_train.py @@ -122,7 +122,8 @@ avg_cost = fluid.layers.mean(cost) optimizer = fluid.optimizer.Adam(learning_rate=0.001) opts = optimizer.minimize(avg_cost) -accuracy = fluid.evaluator.Accuracy(input=predict, label=label) +batch_size = fluid.layers.create_tensor(dtype='int64') +batch_acc = fluid.layers.accuracy(input=predict, label=label, total=batch_size) fluid.memory_optimize(fluid.default_main_program()) @@ -144,13 +145,17 @@ feeder = fluid.DataFeeder(place=place, feed_list=[images, label]) exe.run(fluid.default_startup_program()) i = 0 + +accuracy = fluid.average.WeightedAverage() for pass_id in range(PASS_NUM): - accuracy.reset(exe) + accuracy.reset() for data in train_reader(): - loss, acc = exe.run(fluid.default_main_program(), - feed=feeder.feed(data), - fetch_list=[avg_cost] + accuracy.metrics) - pass_acc = accuracy.eval(exe) + loss, acc, weight = exe.run( + fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[avg_cost, batch_acc, batch_size]) + accuracy.add(value=acc, weight=weight) + pass_acc = accuracy.eval() print("loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str( pass_acc)) # this model is slow, so if we can train two mini batch, we think it works properly. diff --git a/python/paddle/fluid/tests/unittests/test_profiler.py b/python/paddle/fluid/tests/unittests/test_profiler.py index 20b38ecfd7477b7368a1d3bfae3c8dacb06afd39..1da6b94eea30e65913ce713e0e5e355507534161 100644 --- a/python/paddle/fluid/tests/unittests/test_profiler.py +++ b/python/paddle/fluid/tests/unittests/test_profiler.py @@ -37,7 +37,9 @@ class TestProfiler(unittest.TestCase): label = fluid.layers.data(name='y', shape=[1], dtype='int64') cost = fluid.layers.cross_entropy(input=predict, label=label) avg_cost = fluid.layers.mean(cost) - accuracy = fluid.evaluator.Accuracy(input=predict, label=label) + batch_size = fluid.layers.create_tensor(dtype='int64') + batch_acc = fluid.layers.accuracy( + input=predict, label=label, total=batch_size) optimizer = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9) opts = optimizer.minimize(avg_cost, startup_program=startup_program) @@ -46,7 +48,7 @@ class TestProfiler(unittest.TestCase): exe = fluid.Executor(place) exe.run(startup_program) - accuracy.reset(exe) + pass_acc_calculator = fluid.average.WeightedAverage() with profiler.profiler(state, 'total', profile_path) as prof: for iter in range(10): if iter == 2: @@ -57,9 +59,11 @@ class TestProfiler(unittest.TestCase): outs = exe.run(main_program, feed={'x': x, 'y': y}, - fetch_list=[avg_cost] + accuracy.metrics) + fetch_list=[avg_cost, batch_acc, batch_size]) acc = np.array(outs[1]) - pass_acc = accuracy.eval(exe) + b_size = np.array(outs[2]) + pass_acc_calculator.add(value=acc, weight=b_size) + pass_acc = pass_acc_calculator.eval() def test_cpu_profiler(self): self.net_profiler('CPU')