diff --git a/python/paddle/v2/fluid/layers/__init__.py b/python/paddle/v2/fluid/layers/__init__.py index 906a16a49f728526a41d1bc6da3a40e30bbfa33f..c0f4a7f8745fcd5f45aa5ebac90494426792089e 100644 --- a/python/paddle/v2/fluid/layers/__init__.py +++ b/python/paddle/v2/fluid/layers/__init__.py @@ -28,6 +28,8 @@ import math_op_patch from math_op_patch import * import detection from detection import * +import metric +from metric import * __all__ = [] __all__ += math_op_patch.__all__ @@ -38,3 +40,4 @@ __all__ += control_flow.__all__ __all__ += ops.__all__ __all__ += device.__all__ __all__ += detection.__all__ +__all__ += metric.__all__ diff --git a/python/paddle/v2/fluid/layers/metric.py b/python/paddle/v2/fluid/layers/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..3d9157ad4ef9381b70b4007c5bdca91f1482b427 --- /dev/null +++ b/python/paddle/v2/fluid/layers/metric.py @@ -0,0 +1,57 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +All layers just related to metric. +""" + +from ..layer_helper import LayerHelper +from ..initializer import Normal, Constant +from ..framework import Variable +from ..param_attr import ParamAttr + +__all__ = ['accuracy'] + + +def accuracy(input, label, k=1, correct=None, total=None): + """ + This function computes the accuracy using the input and label. + The output is the top_k inputs and their indices. + """ + helper = LayerHelper("accuracy", **locals()) + topk_out = helper.create_tmp_variable(dtype=input.dtype) + topk_indices = helper.create_tmp_variable(dtype="int64") + helper.append_op( + type="top_k", + inputs={"X": [input]}, + outputs={"Out": [topk_out], + "Indices": [topk_indices]}, + attrs={"k": k}) + acc_out = helper.create_tmp_variable(dtype="float32") + if correct is None: + correct = helper.create_tmp_variable(dtype="int64") + if total is None: + total = helper.create_tmp_variable(dtype="int64") + helper.append_op( + type="accuracy", + inputs={ + "Out": [topk_out], + "Indices": [topk_indices], + "Label": [label] + }, + outputs={ + "Accuracy": [acc_out], + "Correct": [correct], + "Total": [total], + }) + return acc_out diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index e8b4cec6ee638b839e2a7c38e032f74b9cd738ef..f090288aa8db2f3fc4f8f757b013281b8c83aba7 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -34,7 +34,6 @@ __all__ = [ 'cos_sim', 'cross_entropy', 'square_error_cost', - 'accuracy', 'chunk_eval', 'sequence_conv', 'conv2d', @@ -1020,40 +1019,6 @@ def square_error_cost(input, label): return square_out -def accuracy(input, label, k=1, correct=None, total=None): - """ - This function computes the accuracy using the input and label. - The output is the top_k inputs and their indices. - """ - helper = LayerHelper("accuracy", **locals()) - topk_out = helper.create_tmp_variable(dtype=input.dtype) - topk_indices = helper.create_tmp_variable(dtype="int64") - helper.append_op( - type="top_k", - inputs={"X": [input]}, - outputs={"Out": [topk_out], - "Indices": [topk_indices]}, - attrs={"k": k}) - acc_out = helper.create_tmp_variable(dtype="float32") - if correct is None: - correct = helper.create_tmp_variable(dtype="int64") - if total is None: - total = helper.create_tmp_variable(dtype="int64") - helper.append_op( - type="accuracy", - inputs={ - "Out": [topk_out], - "Indices": [topk_indices], - "Label": [label] - }, - outputs={ - "Accuracy": [acc_out], - "Correct": [correct], - "Total": [total], - }) - return acc_out - - def chunk_eval(input, label, chunk_scheme,