diff --git a/paddle/operators/math/selected_rows_functor.cu b/paddle/operators/math/selected_rows_functor.cu
index c40649e55ef93dec852ff6949b5cb134495e4ebf..c1dd323ba29e03e3ab4a3e4d7248388b408fb9d6 100644
--- a/paddle/operators/math/selected_rows_functor.cu
+++ b/paddle/operators/math/selected_rows_functor.cu
@@ -227,7 +227,6 @@ template struct SelectedRowsAddToTensor<platform::GPUPlace, float>;
 template struct SelectedRowsAddToTensor<platform::GPUPlace, double>;
 template struct SelectedRowsAddToTensor<platform::GPUPlace, int>;
 template struct SelectedRowsAddToTensor<platform::GPUPlace, int64_t>;
-
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py
index f78d2f814c89aa6b5ee8387f2558a97c754e655c..c37fca8560cecbc66229f63cb7a6589178e80781 100644
--- a/python/paddle/v2/fluid/evaluator.py
+++ b/python/paddle/v2/fluid/evaluator.py
@@ -1,9 +1,14 @@
 import numpy as np
-from paddle.v2.fluid.framework import Program, g_main_program, unique_name, Variable
-import paddle.v2.fluid.core as core
+import paddle.v2.fluid.layers as layers
+from paddle.v2.fluid.framework import Program, unique_name, \
+    Variable
+from paddle.v2.fluid.layer_helper import LayerHelper
 
 
-def _clone_var_in_block_(block, var):
+__all__ = ['Accuracy']
+
+
+def _clone_var_(block, var):
     assert isinstance(var, Variable)
     return block.create_var(
         name=var.name,
@@ -16,175 +21,115 @@ def _clone_var_in_block_(block, var):
 
 class Evaluator(object):
     """
-    Evalutor Base class.
-
-    create metric states
-    add mini-batch evaluator caculate operator
-    add increment operator to accumulate the metric states
+    Base class for all evaluators.
+
+    Args:
+        name(str): The name of the evaluator, such as "accuracy". It is used
+            to generate temporary variable names.
+        main_program(Program, optional): The evaluator should be added to this
+            main_program. Default: g_main_program.
+        startup_program(Program, optional): The parameter should be added to
+            this startup_program. Default: g_startup_program.
+
+    Attributes:
+        states(list): The list of state variables. States are reset to zero
+            when `reset` is invoked.
+        metrics(list): The list of metric variables. They are calculated
+            every mini-batch.
     """
 
     def __init__(self, name, **kwargs):
+        self.states = []
+        self.metrics = []
+        self.helper = LayerHelper(name, **kwargs)
+
+    def reset(self, executor, reset_program=None):
         """
-        init the global states
+        Reset metric states at the beginning of each pass or user-specified batch.
         """
-        self._states = {}
-        if kwargs.has_key("main_program"):
-            self._main_program = kwargs.get("main_program")
-        else:
-            self._main_program = g_main_program
+        if reset_program is None:
+            reset_program = Program()
+
+        for var in self.states:
+            assert isinstance(var, Variable)
+            g_var = _clone_var_(reset_program.current_block(), var)
+            layers.fill_constant(
+                shape=g_var.shape,
+                value=0.0,
+                dtype=g_var.dtype,
+                out=g_var,
+                main_program=reset_program)
 
-    def states(self):
-        return self._states
+        executor.run(reset_program)
 
-    def _update_ops(self, *args, **kwargs):
+    def eval(self, executor, eval_program=None):
         """
-        append update ops to the global states
+        Evaluate the statistics merged over multiple mini-batches.
""" raise NotImplementedError() - def reset(self, executor, reset_program=None): + def create_state(self, suffix, dtype, shape): """ - Clear metric states at the begin of each pass/user specified batch - """ - if reset_program == None: - reset_program = Program() - else: - reset_program = program - block = reset_program.global_block() - for k, var in self._states.iteritems(): - g_var = _clone_var_in_block_(block, var) - zeros = block.create_var(dtype="float32", persistable=True) - block.append_op( - type="fill_constant", - outputs={"Out": [zeros]}, - attrs={ - "shape": g_var.shape, - "value": .0, - "dtype": 5, - }) - block.append_op( - type="scale", inputs={"X": zeros}, outputs={"Out": g_var}) - executor.run(reset_program, fetch_list=self._states.values()) + Create state variable. + + NOTE: It is not a public API. + + Args: + suffix(str): the state suffix. + dtype(str|core.DataType): the state data type + shape(tuple|list): the shape of state + + Returns: State variable - def eval(self, executor, eval_program=None): - """ - Merge the mini-batch statistics to form the evaluation result for multiple mini-batches. """ - raise NotImplementedError() + state = self.helper.create_variable( + name="_".join([unique_name(self.helper.name), suffix]), + persistable=True, + dtype=dtype, + shape=shape) + self.states.append(state) + return state class Accuracy(Evaluator): """ - Accuracy need two state variable Total, Correct + Average Accuracy for multiple mini-batches. """ - def __init__(self, *args, **kwargs): + def __init__(self, input, label, k=1, **kwargs): super(Accuracy, self).__init__("accuracy", **kwargs) - block = self._main_program.global_block() - g_total = block.create_var( - name=unique_name("Total"), - persistable=True, - dtype="int64", - shape=[1]) - g_correct = block.create_var( - name=unique_name("Correct"), - persistable=True, - dtype="int64", - shape=[1]) - self._states["Total"] = g_total - self._states["Correct"] = g_correct - - def _update_ops(self, input, label, k=1, **kwargs): - block = self._main_program.global_block() - topk_out = block.create_var(dtype=input.dtype) - topk_indices = block.create_var(dtype="int64") - block.append_op( - type="top_k", - inputs={"X": [input]}, - outputs={"Out": [topk_out], - "Indices": [topk_indices]}, - attrs={"k": k}) - acc_out = block.create_var(dtype=kwargs.get("out_dtype", "float32")) - correct = block.create_var(dtype="int64", persistable=True) - total = block.create_var(dtype="int64", persistable=True) - block.append_op( - type="accuracy", - inputs={ - "Out": [topk_out], - "Indices": [topk_indices], - "Label": [label] - }, - outputs={ - "Accuracy": [acc_out], - "Correct": [correct], - "Total": [total], - }) - - block.append_op( - type="cast", - inputs={"X": [self._states["Total"]]}, - outputs={"Out": [self._states["Total"]]}, - attrs={ - "in_dtype": 5, # float32 - "out_dtype": 2, # int32 - }) - block.append_op( - type="cast", - inputs={"X": [self._states["Correct"]]}, - outputs={"Out": [self._states["Correct"]]}, - attrs={ - "in_dtype": 5, - "out_dtype": 2, - }) - - block.append_op( - type="elementwise_add", - inputs={"X": [self._states["Total"]], - "Y": [total]}, - outputs={"Out": [self._states["Total"]]}) - block.append_op( - type="elementwise_add", - inputs={"X": [self._states["Correct"]], - "Y": [correct]}, - outputs={"Out": [self._states["Correct"]]}) - - return acc_out + main_program = self.helper.main_program + if main_program.current_block().idx != 0: + raise ValueError("You can only invoke Evaluator in root block") + + self.total = 
+        self.correct = self.create_state(
+            dtype='int64', shape=[1], suffix='correct')
+        kwargs = {'main_program': main_program}
+        total = self.helper.create_tmp_variable(dtype='int')
+        correct = self.helper.create_tmp_variable(dtype='int')
+        acc = layers.accuracy(
+            input=input,
+            label=label,
+            k=k,
+            total=total,
+            correct=correct,
+            **kwargs)
+        total = layers.cast(x=total, dtype='int64', **kwargs)
+        correct = layers.cast(x=correct, dtype='int64', **kwargs)
+        layers.sums(input=[self.total, total], out=self.total, **kwargs)
+        layers.sums(input=[self.correct, correct], out=self.correct, **kwargs)
+
+        self.metrics.append(acc)
 
     def eval(self, executor, eval_program=None):
-        if eval_program != None:
-            eval_program = eval_program
-        else:
+        if eval_program is None:
             eval_program = Program()
-        block = eval_program.global_block()
-        eval_out = block.create_var(dtype=self._states["Total"].dtype)
-        e_total = _clone_var_in_block_(block, self._states["Total"])
-        e_correct = _clone_var_in_block_(block, self._states["Correct"])
-        block.append_op(
-            type="cast",
-            inputs={"X": [e_total]},
-            outputs={"Out": [e_total]},
-            attrs={
-                "in_dtype": 2,  # int32
-                "out_dtype": 5,  # float32
-            })
-        block.append_op(
-            type="cast",
-            inputs={"X": [e_correct]},
-            outputs={"Out": [e_correct]},
-            attrs={
-                "in_dtype": 2,
-                "out_dtype": 5,
-            })
-        block.append_op(
-            type="elementwise_div",
-            inputs={"X": e_correct,
-                    "Y": e_total},
-            outputs={"Out": eval_out})
-        out = executor.run(eval_program, fetch_list=[eval_out])
-        return np.array(out[0])
-
-
-def accuracy(*args, **kwargs):
-    cls = Accuracy(*args, **kwargs)
-    out = cls._update_ops(*args, **kwargs)
-    return cls, out
+        block = eval_program.current_block()
+        kwargs = {'main_program': eval_program}
+        total = _clone_var_(block, self.total)
+        correct = _clone_var_(block, self.correct)
+        total = layers.cast(total, dtype='float32', **kwargs)
+        correct = layers.cast(correct, dtype='float32', **kwargs)
+        out = layers.elementwise_div(x=correct, y=total, **kwargs)
+        return np.array(executor.run(eval_program, fetch_list=[out])[0])
diff --git a/python/paddle/v2/fluid/layers.py b/python/paddle/v2/fluid/layers.py
index d094035fe5cae2e77fc2364e8ccb03c350f1301a..ca0c10e7009e7227c611e1a2f07acbc6eeb83ac5 100644
--- a/python/paddle/v2/fluid/layers.py
+++ b/python/paddle/v2/fluid/layers.py
@@ -418,6 +418,7 @@ def _create_op_func_(op_type):
 _create_op_func_('mean')
 _create_op_func_('mul')
 _create_op_func_('elementwise_add')
+_create_op_func_('elementwise_div')
 _create_op_func_('dropout')
 _create_op_func_('reshape')
 _create_op_func_('sigmoid')
@@ -457,13 +458,14 @@ def concat(input, axis, main_program=None, startup_program=None):
     return out
 
 
-def sums(input, main_program=None, startup_program=None):
+def sums(input, out=None, main_program=None, startup_program=None):
     """
     This function takes in the input and performs the sum operation on it
     and returns that as the output.
     """
     helper = LayerHelper('sum', **locals())
-    out = helper.create_tmp_variable(dtype=helper.input_dtype())
+    if out is None:
+        out = helper.create_tmp_variable(dtype=helper.input_dtype())
     helper.append_op(type='sum', inputs={'X': input}, outputs={'Out': out})
     return out
 
@@ -606,7 +608,7 @@ def square_error_cost(input, label, **kwargs):
     return square_out
 
 
-def accuracy(input, label, k=1, **kwargs):
+def accuracy(input, label, k=1, correct=None, total=None, **kwargs):
     """
     This function computes the accuracy using the input and label.
     The output is the top_k inputs and their indices.
@@ -620,10 +622,11 @@
         outputs={"Out": [topk_out],
                  "Indices": [topk_indices]},
         attrs={"k": k})
-    acc_out_dtype = kwargs.get("out_dtype", "float32")
     acc_out = helper.create_tmp_variable(dtype="float32")
-    correct = helper.create_tmp_variable(dtype="int64")
-    total = helper.create_tmp_variable(dtype="int64")
+    if correct is None:
+        correct = helper.create_tmp_variable(dtype="int64")
+    if total is None:
+        total = helper.create_tmp_variable(dtype="int64")
     helper.append_op(
         type="accuracy",
         inputs={
@@ -1355,6 +1358,19 @@ def lod_rank_table(x, level=0, main_program=None):
     return table
 
 
+def topk(input, k, main_program=None, startup_program=None):
+    helper = LayerHelper('topk', **locals())
+    topk_out = helper.create_tmp_variable(dtype=input.dtype)
+    topk_indices = helper.create_tmp_variable(dtype='int64')
+    helper.append_op(
+        type='top_k',
+        inputs={'X': [input]},
+        outputs={'Out': [topk_out],
+                 'Indices': [topk_indices]},
+        attrs={'k': k})
+    return topk_out, topk_indices
+
+
 def lod_tensor_to_array(x, table, main_program=None):
     """
     This function creates an operator to convert an LOD_Tensor to
@@ -1388,14 +1404,20 @@ def array_to_lod_tensor(x, table, main_program=None):
     return tmp
 
 
-def fill_constant(shape, dtype, value, main_program=None, startup_program=None):
+def fill_constant(shape,
+                  dtype,
+                  value,
+                  out=None,
+                  main_program=None,
+                  startup_program=None):
     """
     This function creates a tensor , with shape as mentioned in the input
     and specified dtype and fills this up with a constant value that
     comes in the input. It also sets the stop_gradient to be True.
     """
     helper = LayerHelper("fill_constant", **locals())
-    out = helper.create_tmp_variable(dtype=dtype)
+    if out is None:
+        out = helper.create_tmp_variable(dtype=dtype)
     helper.append_op(
         type='fill_constant',
         inputs={},
diff --git a/python/paddle/v2/fluid/tests/book/test_image_classification_train.py b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py
index 76cbd410f94a4be04ba71d1e3175eaed590ac80a..b555b49ab228f018384dfa407fb34bcd39059e2e 100644
--- a/python/paddle/v2/fluid/tests/book/test_image_classification_train.py
+++ b/python/paddle/v2/fluid/tests/book/test_image_classification_train.py
@@ -5,7 +5,6 @@ import paddle.v2.fluid.framework as framework
 import paddle.v2.fluid.layers as layers
 import paddle.v2.fluid.nets as nets
 import paddle.v2.fluid.evaluator as evaluator
-from paddle.v2.fluid.io import get_inference_program
 from paddle.v2.fluid.executor import Executor
 from paddle.v2.fluid.initializer import XavierInitializer
 from paddle.v2.fluid.optimizer import AdamOptimizer
@@ -110,18 +109,16 @@ avg_cost = layers.mean(x=cost)
 optimizer = AdamOptimizer(learning_rate=0.001)
 opts = optimizer.minimize(avg_cost)
 
-accuracy, acc_out = evaluator.accuracy(input=predict, label=label)
+accuracy = evaluator.Accuracy(input=predict, label=label)
 
 BATCH_SIZE = 128
 PASS_NUM = 1
 
 train_reader = paddle.batch(
     paddle.reader.shuffle(
-        paddle.dataset.cifar.train10(), buf_size=BATCH_SIZE * 10),
+        paddle.dataset.cifar.train10(), buf_size=128 * 10),
     batch_size=BATCH_SIZE)
 
-test_reader = paddle.batch(paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE)
-
 place = core.CPUPlace()
 exe = Executor(place)
 
@@ -147,46 +144,15 @@ for pass_id in range(PASS_NUM):
         outs = exe.run(framework.default_main_program(),
                        feed={"pixel": tensor_img,
                              "label": tensor_y},
-                       fetch_list=[avg_cost, acc_out])
+                       fetch_list=[avg_cost] + accuracy.metrics)
 
         loss = np.array(outs[0])
         acc = np.array(outs[1])
         pass_acc = accuracy.eval(exe)
-
-        batch_id = batch_id + 1
-
-        test_accuracy, test_acc_out = evaluator.accuracy(
-            input=predict, label=label)
-
-        test_target = [avg_cost, test_acc_out] + test_accuracy.states().values()
-        inference_program = get_inference_program(test_target)
-
-        test_accuracy.reset(exe)
-
-        for data in test_reader():
-            x_data = np.array(map(lambda x: x[0].reshape(data_shape),
-                                  data)).astype("float32")
-            y_data = np.array(map(lambda x: x[1], data)).astype("int64")
-            y_data = np.expand_dims(y_data, axis=1)
-
-            tensor_x = core.LoDTensor()
-            tensor_x.set(x_data, place)
-
-            tensor_y = core.LoDTensor()
-            tensor_y.set(y_data, place)
-
-            outs = exe.run(inference_program,
                           feed={'pixel': tensor_x,
-                                 'label': tensor_y},
-                           fetch_list=[avg_cost, test_acc_out])
-            out = np.array(outs[0])
-            acc = np.array(outs[1])
-
-            test_pass_acc = test_accuracy.eval(exe)
-
         print("pass_id:" + str(pass_id) + " batch_id:" + str(batch_id) +
               " loss:" + str(loss) + " acc:" + str(acc) + " pass_acc:" + str(
-                  pass_acc) + " test_pass_acc:" + str(test_pass_acc))
+                  pass_acc))
+        batch_id = batch_id + 1
 
         if batch_id > 1:
             # this model is slow, so if we can train two mini batch, we think it works properly.
diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py
index 0bea5f95c895b278db86f25f54e2795d3ec0af69..97f1f12724b7d6a7cddcd6d405065da7303574e4 100644
--- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py
+++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_conv.py
@@ -31,7 +31,7 @@ avg_cost = layers.mean(x=cost)
 optimizer = AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999)
 opts = optimizer.minimize(avg_cost)
 
-accuracy, acc_out = evaluator.accuracy(input=predict, label=label)
+accuracy = evaluator.Accuracy(input=predict, label=label)
 
 BATCH_SIZE = 50
 PASS_NUM = 3
@@ -61,7 +61,7 @@ for pass_id in range(PASS_NUM):
         outs = exe.run(framework.default_main_program(),
                        feed={"pixel": tensor_img,
                              "label": tensor_y},
-                       fetch_list=[avg_cost, acc_out])
+                       fetch_list=[avg_cost] + accuracy.metrics)
         loss = np.array(outs[0])
         acc = np.array(outs[1])
         pass_acc = accuracy.eval(exe)
diff --git a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py
index f57a5c8d98cd8b89e1d300b4d1fe00d6b24b0d68..7dbb34f5da66de774ee84a4cdb028b10e326da4d 100644
--- a/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py
+++ b/python/paddle/v2/fluid/tests/book/test_recognize_digits_mlp.py
@@ -36,7 +36,7 @@ avg_cost = layers.mean(x=cost)
 optimizer = MomentumOptimizer(learning_rate=0.001, momentum=0.9)
 opts = optimizer.minimize(avg_cost)
 
-accuracy, acc_out = evaluator.accuracy(input=predict, label=label)
+accuracy = evaluator.Accuracy(input=predict, label=label)
 
 train_reader = paddle.batch(
     paddle.reader.shuffle(
@@ -67,15 +67,14 @@ for pass_id in range(PASS_NUM):
         outs = exe.run(framework.default_main_program(),
                        feed={'x': tensor_x,
                              'y': tensor_y},
-                       fetch_list=[avg_cost, acc_out])
+                       fetch_list=[avg_cost] + accuracy.metrics)
         out = np.array(outs[0])
         acc = np.array(outs[1])
         pass_acc = accuracy.eval(exe)
 
-        test_accuracy, test_acc_out = evaluator.accuracy(
-            input=predict, label=label)
+        test_accuracy = evaluator.Accuracy(input=predict, label=label)
 
-        test_target = [avg_cost, test_acc_out] + test_accuracy.states().values()
+        test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states
         inference_program = get_inference_program(test_target)
 
         test_accuracy.reset(exe)
@@ -93,7 +92,7 @@ for pass_id in range(PASS_NUM):
             outs = exe.run(inference_program,
                            feed={'x': tensor_x,
                                  'y': tensor_y},
-                           fetch_list=[avg_cost, test_acc_out])
+                           fetch_list=[avg_cost] + test_accuracy.metrics)
 
             out = np.array(outs[0])
             acc = np.array(outs[1])
diff --git a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py
index 3103be83a63d64fcba87132ddc5d830b92047b27..054cdb324c7cd91f9c9e51a54b50d2880e21b05d 100644
--- a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py
+++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_conv.py
@@ -32,9 +32,9 @@ def convolution_net(input_dim, class_dim=2, emb_dim=32, hid_dim=32):
     cost = layers.cross_entropy(input=prediction, label=label)
     avg_cost = layers.mean(x=cost)
     adam_optimizer = AdamOptimizer(learning_rate=0.002)
-    opts = adam_optimizer.minimize(avg_cost)
-    accuracy, acc_out = evaluator.accuracy(input=prediction, label=label)
-    return avg_cost, accuracy, acc_out
+    adam_optimizer.minimize(avg_cost)
+    accuracy = evaluator.Accuracy(input=prediction, label=label)
+    return avg_cost, accuracy, accuracy.metrics[0]
 
 
 def to_lodtensor(data, place):
diff --git a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py
index 208978224f4e83a23efadae37fbe51d0d59dafe8..854ef82614a9959ce97a149035cd9a682d534a72 100644
--- a/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py
+++ b/python/paddle/v2/fluid/tests/book/test_understand_sentiment_dynamic_lstm.py
@@ -41,9 +41,9 @@ def stacked_lstm_net(input_dim,
     cost = layers.cross_entropy(input=prediction, label=label)
     avg_cost = layers.mean(x=cost)
     adam_optimizer = AdamOptimizer(learning_rate=0.002)
-    opts = adam_optimizer.minimize(avg_cost)
-    accuracy, acc_out = evaluator.accuracy(input=prediction, label=label)
-    return avg_cost, accuracy, acc_out
+    adam_optimizer.minimize(avg_cost)
+    accuracy = evaluator.Accuracy(input=prediction, label=label)
+    return avg_cost, accuracy, accuracy.metrics[0]
 
 
 def to_lodtensor(data, place):
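
Usage sketch: a minimal illustration of the new evaluator workflow, distilled from the updated book tests above. `predict`, `label`, `avg_cost`, `exe`, `train_reader`, `PASS_NUM`, and the `feed_fn` helper are stand-ins for a network, executor, and data pipeline the caller already has; only `evaluator.Accuracy`, `accuracy.metrics`, `accuracy.reset`, and `accuracy.eval` come from this patch.

    import numpy as np
    import paddle.v2.fluid.evaluator as evaluator
    import paddle.v2.fluid.framework as framework

    # One Accuracy instance creates its persistent Total/Correct states in the
    # root block and appends the per-batch accuracy, cast, and sums ops.
    accuracy = evaluator.Accuracy(input=predict, label=label)

    for pass_id in range(PASS_NUM):
        accuracy.reset(exe)  # zero the accumulated Total/Correct states
        for data in train_reader():
            outs = exe.run(framework.default_main_program(),
                           feed=feed_fn(data),  # hypothetical feeding helper
                           fetch_list=[avg_cost] + accuracy.metrics)
            batch_acc = np.array(outs[1])  # accuracy of this mini-batch only
            pass_acc = accuracy.eval(exe)  # Correct / Total accumulated so far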