#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
All layers just related to metric.
"""
import numpy as np

import paddle
from paddle import _legacy_C_ops
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import Variable, _create_tensor, _non_static_mode
from paddle.fluid.layer_helper import LayerHelper
from paddle.nn.initializer import ConstantInitializer

__all__ = []


def accuracy(input, label, k=1, correct=None, total=None):
    """
    Accuracy layer.

    This function computes the accuracy using the input and label.
    If the correct label occurs in top k predictions, then correct will
    increment by one.
    Refer to https://en.wikipedia.org/wiki/Precision_and_recall

    Note:
        The input and label dtype can be different. In static-graph mode the
        returned accuracy variable is created with dtype float32.

    Args:
        input(Tensor): The input of accuracy layer, which is the predictions of network. A Tensor with type float32,float64.
            The shape is ``[sample_number, class_dim]`` .
        label(Tensor): The label of dataset.  Tensor with type int32,int64. The shape is ``[sample_number, 1]`` .
        k(int|Tensor, optional): The top k predictions for each class will be checked. Either a Python int or a
            Tensor holding the value. Default is 1.
        correct(Tensor, optional): The correct predictions count. A Tensor with type int64 or int32. Default is None,
            in which case an int32 tensor/variable is created internally.
        total(Tensor, optional): The total entries count. A tensor with type int64 or int32. Default is None,
            in which case an int32 tensor/variable is created internally.

    Returns:
        Tensor, The correct rate. A Tensor with type float32.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.static as static
            import paddle.nn.functional as F
            paddle.enable_static()
            data = static.data(name="input", shape=[-1, 32, 32], dtype="float32")
            label = static.data(name="label", shape=[-1,1], dtype="int")
            fc_out = static.nn.fc(x=data, size=10)
            predict = F.softmax(x=fc_out)
            result = static.accuracy(input=predict, label=label, k=5)
            place = paddle.CPUPlace()
            exe = static.Executor(place)
            exe.run(static.default_startup_program())
            x = np.random.rand(3, 32, 32).astype("float32")
            y = np.array([[1],[0],[1]])
            output = exe.run(feed={"input": x,"label": y},
                             fetch_list=[result])
            print(output)
            # [array(0.33333334, dtype=float32)]

    """
    if _non_static_mode():
        if correct is None:
            correct = _create_tensor(dtype="int32")
        if total is None:
            total = _create_tensor(dtype="int32")

        # The eager op takes k as a plain attribute, so a tensor-valued k is
        # converted to a Python scalar first.
        _k = np.array(k).item(0) if isinstance(k, Variable) else k
        topk_out, topk_indices = _legacy_C_ops.top_k_v2(
            input, 'k', _k, 'sorted', False
        )
        _acc, _, _ = _legacy_C_ops.accuracy(
            topk_out, topk_indices, label, correct, total
        )
        return _acc

    # locals() is forwarded as the helper's kwargs; keep this call ahead of any
    # new local bindings so only the call arguments are passed.
    helper = LayerHelper("accuracy", **locals())
    check_variable_and_dtype(
        input, 'input', ['float16', 'uint16', 'float32', 'float64'], 'accuracy'
    )
    topk_out = helper.create_variable_for_type_inference(dtype=input.dtype)
    topk_indices = helper.create_variable_for_type_inference(dtype="int64")

    # top_k_v2 receives k either as the tensor input "K" or as the attribute
    # "k". `attrs` must exist in both cases: previously it was only bound in
    # the attribute branch, so a tensor-valued k failed with UnboundLocalError
    # on the 'sorted' assignment below.
    inputs = {"X": [input]}
    attrs = {}
    if isinstance(k, Variable):
        inputs['K'] = [k]
    else:
        attrs['k'] = k
    attrs['sorted'] = False
    helper.append_op(
        type="top_k_v2",
        inputs=inputs,
        attrs=attrs,
        outputs={"Out": [topk_out], "Indices": [topk_indices]},
    )

    acc_out = helper.create_variable_for_type_inference(dtype="float32")
    if correct is None:
        correct = helper.create_variable_for_type_inference(dtype="int32")
    if total is None:
        total = helper.create_variable_for_type_inference(dtype="int32")
    helper.append_op(
        type="accuracy",
        inputs={"Out": [topk_out], "Indices": [topk_indices], "Label": [label]},
        outputs={
            "Accuracy": [acc_out],
            "Correct": [correct],
            "Total": [total],
        },
    )
    return acc_out


def auc(
    input,
    label,
    curve='ROC',
    num_thresholds=2**12 - 1,
    topk=1,
    slide_steps=1,
    ins_tag_weight=None,
):
    """
    **Area Under the Curve (AUC) Layer**

    This implementation computes the AUC according to forward output and label.
    It is used very widely in binary classification evaluation.

    Note: If input label contains values other than 0 and 1, it will be cast
    to `bool`. Find the relevant definitions `here
    <https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve>`_.

    There are two types of possible curves:

        1. ROC: Receiver operating characteristic;
        2. PR: Precision Recall

    Two ``auc`` ops are appended to the program: one that keeps sliding-window
    (batch) statistics and one that keeps statistics over all steps.

    Args:
        input(Tensor): A floating-point 2D Tensor, values are in the range
                         [0, 1]. Each row is sorted in descending order. This
                         input should be the output of topk. Typically, this
                         Tensor indicates the probability of each label.
                         A Tensor with type float32,float64.
        label(Tensor): A 2D int Tensor indicating the label of the training
                         data. The height is batch size and width is always 1.
                         A Tensor with type int32,int64.
        curve(str, optional): Curve type, can be 'ROC' or 'PR'. Default 'ROC'.
        num_thresholds(int, optional): The number of thresholds to use when discretizing
                             the roc curve. Default 4095.
        topk(int, optional): Only topk number of prediction output will be used for auc.
                             NOTE(review): this argument is not referenced by the ops built in this function.
        slide_steps(int, optional): Number of steps covered by the batch AUC. ``slide_steps=1`` uses only the
                             current step, ``slide_steps=3`` uses the current step and the two previous steps,
                             ``slide_steps=0`` uses all steps. Default 1.
        ins_tag_weight(Tensor, optional): A 2D Tensor indicating the data's tag weight, 1 means real data,
                             0 means fake data. A Tensor with type float32,float64. Default None, in which case
                             a ``[1, 1]`` float32 tensor of value 1 is created.
                             NOTE(review): in this function the tensor is dtype-checked but is not passed to
                             the ``auc`` op.

    Returns:
        tuple: ``(auc_out, batch_auc_out, [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg])``

            auc_out: the AUC computed over all steps (float64)
            batch_auc_out: the AUC computed over the sliding window of steps (float64)
            batch_stat_pos: the statistic value for label=1 at the time of batch calculation
            batch_stat_neg: the statistic value for label=0 at the time of batch calculation
            stat_pos: the statistic for label=1 at the time of calculation
            stat_neg: the statistic for label=0 at the time of calculation

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np
            paddle.enable_static()

            data = paddle.static.data(name="input", shape=[-1, 32,32], dtype="float32")
            label = paddle.static.data(name="label", shape=[-1], dtype="int")
            fc_out = paddle.static.nn.fc(x=data, size=2)
            predict = paddle.nn.functional.softmax(x=fc_out)
            result=paddle.static.auc(input=predict, label=label)

            place = paddle.CPUPlace()
            exe = paddle.static.Executor(place)

            exe.run(paddle.static.default_startup_program())
            x = np.random.rand(3,32,32).astype("float32")
            y = np.array([1,0,1])
            output= exe.run(feed={"input": x,"label": y},
                             fetch_list=[result[0]])
            print(output)

            #you can learn the usage of ins_tag_weight by the following code.
            '''
            import paddle
            import numpy as np
            paddle.enable_static()

            data = paddle.static.data(name="input", shape=[-1, 32,32], dtype="float32")
            label = paddle.static.data(name="label", shape=[-1], dtype="int")
            ins_tag_weight = paddle.static.data(name='ins_tag', shape=[-1,16], lod_level=0, dtype='float64')
            fc_out = paddle.static.nn.fc(x=data, size=2)
            predict = paddle.nn.functional.softmax(x=fc_out)
            result=paddle.static.auc(input=predict, label=label, ins_tag_weight=ins_tag_weight)

            place = paddle.CPUPlace()
            exe = paddle.static.Executor(place)

            exe.run(paddle.static.default_startup_program())
            x = np.random.rand(3,32,32).astype("float32")
            y = np.array([1,0,1])
            z = np.array([1,0,1])
            output= exe.run(feed={"input": x,"label": y, "ins_tag_weight":z},
                             fetch_list=[result[0]])
            print(output)
            '''

    """
    # Must remain the first statement: locals() is forwarded as the helper's
    # kwargs, so only the call arguments may be bound at this point.
    helper = LayerHelper("auc", **locals())

    if ins_tag_weight is None:
        ins_tag_weight = paddle.tensor.fill_constant(
            shape=[1, 1], dtype="float32", value=1.0
        )

    # Validate the tensor arguments in order; the first failing check raises.
    for tensor, arg_name, allowed_dtypes in (
        (input, 'input', ['float32', 'float64']),
        (label, 'label', ['int32', 'int64']),
        (ins_tag_weight, 'ins_tag_weight', ['float32', 'float64']),
    ):
        check_variable_and_dtype(tensor, arg_name, allowed_dtypes, 'auc')

    auc_out = helper.create_variable_for_type_inference(dtype="float64")
    batch_auc_out = helper.create_variable_for_type_inference(dtype="float64")

    # One histogram bucket per threshold, plus one.
    bucket_count = num_thresholds + 1

    def _persistent_stat(shape):
        # Statistics are persistable so they can accumulate across batches.
        return helper.create_global_variable(
            persistable=True, dtype='int64', shape=shape
        )

    # Batch AUC statistics: slide_steps + 1 groups of buckets. The first
    # slide_steps groups store historical batch-level values and the last group
    # stores their sum. The group used by the newest batch is determined by
    # batch_id mod slide_steps, and batch_id is kept in the extra final slot.
    batch_stat_pos = _persistent_stat([(1 + slide_steps) * bucket_count + 1])
    batch_stat_neg = _persistent_stat([(1 + slide_steps) * bucket_count + 1])

    # Global AUC statistics: no batch id needs to be maintained.
    stat_pos = _persistent_stat([1, bucket_count])
    stat_neg = _persistent_stat([1, bucket_count])

    stat_vars = [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg]
    for stat in stat_vars:
        helper.set_variable_initializer(
            stat,
            ConstantInitializer(value=0.0, force_cpu=False),
        )

    def _append_auc_op(result, pos, neg, steps):
        # The op reads the running statistics and writes them back in place.
        helper.append_op(
            type="auc",
            inputs={
                "Predict": [input],
                "Label": [label],
                "StatPos": [pos],
                "StatNeg": [neg],
            },
            attrs={
                "curve": curve,
                "num_thresholds": num_thresholds,
                "slide_steps": steps,
            },
            outputs={
                "AUC": [result],
                "StatPosOut": [pos],
                "StatNegOut": [neg],
            },
        )

    # Batch AUC over the sliding window, then global AUC (slide_steps fixed at 0).
    _append_auc_op(batch_auc_out, batch_stat_pos, batch_stat_neg, slide_steps)
    _append_auc_op(auc_out, stat_pos, stat_neg, 0)

    return auc_out, batch_auc_out, stat_vars


def ctr_metric_bundle(input, label, ins_tag_weight=None):
    """
    ctr related metric layer

    This function helps compute the ctr related metrics: RMSE, MAE, predicted_ctr, q_value.
    It appends ops that accumulate running (local) sums; to compute the final values of these
    metrics, do the following computations using the total instance number:

        MAE = local_abserr / instance number
        RMSE = sqrt(local_sqrerr / instance number)
        predicted_ctr = local_prob / instance number
        q = local_q / instance number

    Note that if you are doing a distributed job, you should all-reduce these metrics and the
    instance number first.

    Args:
        input(Tensor): The predicted values. Its shape must equal the shape of ``label``
                         (checked with an ``assert``).
        label(Tensor): The label of the training data, with the same shape as ``input``.
        ins_tag_weight(Tensor, optional): The ins_tag_weight of the training data. 1 means real
                         data, 0 means fake data. Default None, in which case a ``[1, 1]`` float32
                         tensor of value 1 is created. Note that a ``slice`` op (axis 0, range
                         [0, 1)) writes its result back into this same variable.

    Returns:
        tuple: ``(local_sqrerr, local_abserr, local_prob, local_q, local_pos_num, local_ins_num)``

            local_sqrerr(Tensor): Local sum of squared error
            local_abserr(Tensor): Local sum of abs error
            local_prob(Tensor): Local sum of predicted ctr
            local_q(Tensor): Local sum of q value (sum of sigmoid(input), scaled by ins_tag_weight)
            local_pos_num(Tensor): Local sum of the label values
            local_ins_num(Tensor): Local instance count, scaled by ins_tag_weight

    Examples 1:
        .. code-block:: python

            import paddle
            paddle.enable_static()
            data = paddle.static.data(name="data", shape=[32, 32], dtype="float32")
            label = paddle.static.data(name="label", shape=[-1, 1], dtype="int32")
            predict = paddle.nn.functional.sigmoid(paddle.static.nn.fc(x=data, size=1))
            auc_out = paddle.static.ctr_metric_bundle(input=predict, label=label)
    Examples 2:
        .. code-block:: python

            import paddle
            paddle.enable_static()
            data = paddle.static.data(name="data", shape=[32, 32], dtype="float32")
            label = paddle.static.data(name="label", shape=[-1, 1], dtype="int32")
            predict = paddle.nn.functional.sigmoid(paddle.static.nn.fc(x=data, size=1))
            ins_tag_weight = paddle.static.data(name='ins_tag', shape=[-1,16], lod_level=0, dtype='int64')
            auc_out = paddle.static.ctr_metric_bundle(input=predict, label=label, ins_tag_weight=ins_tag_weight)

    """
    if ins_tag_weight is None:
        ins_tag_weight = paddle.tensor.fill_constant(
            shape=[1, 1], dtype="float32", value=1.0
        )

    assert input.shape == label.shape
    # locals() is forwarded as the helper's kwargs; the nested helpers below are
    # defined after this call so they are not included.
    helper = LayerHelper("ctr_metric_bundle", **locals())

    def _float_var(persistable, shape):
        return helper.create_global_variable(
            persistable=persistable, dtype='float32', shape=shape
        )

    # Running totals that persist across batches.
    (
        local_abserr,
        local_sqrerr,
        local_prob,
        local_q,
        local_pos_num,
        local_ins_num,
    ) = [_float_var(True, [1]) for _ in range(6)]

    # Per-batch intermediate tensors.
    tmp_res_elesub, tmp_res_sigmoid, tmp_ones = [
        _float_var(False, [-1]) for _ in range(3)
    ]

    # Per-batch scalar results.
    (
        batch_prob,
        batch_abserr,
        batch_sqrerr,
        batch_q,
        batch_pos_num,
        batch_ins_num,
    ) = [_float_var(False, [1]) for _ in range(6)]

    zero_initialized = (
        local_abserr,
        batch_abserr,
        local_sqrerr,
        batch_sqrerr,
        local_prob,
        batch_prob,
        local_q,
        batch_q,
        batch_pos_num,
        batch_ins_num,
        local_pos_num,
        local_ins_num,
    )
    for metric_var in zero_initialized:
        helper.set_variable_initializer(
            metric_var,
            paddle.nn.initializer.ConstantInitializer(
                value=0.0, force_cpu=True
            ),
        )

    def _unary(op_type, src, dst):
        # Append a single-input op: dst = op_type(src).
        helper.append_op(
            type=op_type,
            inputs={"X": [src]},
            outputs={"Out": [dst]},
        )

    def _accumulate(batch_value, running_total):
        # running_total = batch_value + running_total (written back in place).
        helper.append_op(
            type="elementwise_add",
            inputs={"X": [batch_value], "Y": [running_total]},
            outputs={"Out": [running_total]},
        )

    # Prediction error shared by the squared-error and abs-error sums.
    helper.append_op(
        type="elementwise_sub",
        inputs={"X": [input], "Y": [label]},
        outputs={"Out": [tmp_res_elesub]},
    )

    _unary("squared_l2_norm", tmp_res_elesub, batch_sqrerr)
    _accumulate(batch_sqrerr, local_sqrerr)

    _unary("l1_norm", tmp_res_elesub, batch_abserr)
    _accumulate(batch_abserr, local_abserr)

    _unary("reduce_sum", input, batch_prob)
    _accumulate(batch_prob, local_prob)

    # batch_q is accumulated later, after it has been scaled by ins_tag_weight.
    _unary("sigmoid", input, tmp_res_sigmoid)
    _unary("reduce_sum", tmp_res_sigmoid, batch_q)

    _unary("reduce_sum", label, batch_pos_num)
    _accumulate(batch_pos_num, local_pos_num)

    # Instance count for the batch: sum of a ones tensor sized like label.
    helper.append_op(
        type='fill_constant_batch_size_like',
        inputs={"Input": label},
        outputs={'Out': [tmp_ones]},
        attrs={
            'shape': [-1, 1],
            'dtype': tmp_ones.dtype,
            'value': float(1.0),
        },
    )
    _unary("reduce_sum", tmp_ones, batch_ins_num)

    # if data is fake, return 0
    # Keep only the first entry along axis 0; the result overwrites ins_tag_weight.
    helper.append_op(
        type="slice",
        inputs={'Input': ins_tag_weight},
        attrs={'axes': [0], 'starts': [0], 'ends': [1]},
        outputs={"Out": ins_tag_weight},
    )

    axis = helper.kwargs.get('axis', 0)

    def _scale_by_tag_weight(value):
        # value = value * ins_tag_weight (written back in place).
        helper.append_op(
            type="elementwise_mul",
            inputs={"X": [value], "Y": [ins_tag_weight]},
            outputs={"Out": [value]},
            attrs={'axis': axis},
        )

    _scale_by_tag_weight(batch_ins_num)
    _accumulate(batch_ins_num, local_ins_num)

    _scale_by_tag_weight(batch_q)
    _accumulate(batch_q, local_q)

    return (
        local_sqrerr,
        local_abserr,
        local_prob,
        local_q,
        local_pos_num,
        local_ins_num,
    )