未验证 提交 728d5b3a 编写于 作者: W wangzhen38 提交者: GitHub

[DOC FIX]fix code demo of auc (#45200)

* [DOC FIX]fix code demo of auc

* [doc fix] fix doc of auc
上级 78916a7a
......@@ -164,7 +164,7 @@ def auc(input,
batch_stat_pos, batch_stat_neg, stat_pos, stat_neg ]
Data type is Tensor, supporting float32, float64.
Examples 1:
Examples:
.. code-block:: python
import paddle
......@@ -173,8 +173,8 @@ def auc(input,
data = paddle.static.data(name="input", shape=[-1, 32,32], dtype="float32")
label = paddle.static.data(name="label", shape=[-1], dtype="int")
fc_out = paddle.static.nn.fc(input=data, size=2)
predict = paddle.nn.functional.softmax(input=fc_out)
fc_out = paddle.static.nn.fc(x=data, size=2)
predict = paddle.nn.functional.softmax(x=fc_out)
result=paddle.static.auc(input=predict, label=label)
place = paddle.CPUPlace()
......@@ -186,19 +186,18 @@ def auc(input,
output= exe.run(feed={"input": x,"label": y},
fetch_list=[result[0]])
print(output)
#[array([0.5])]
Examples 2:
.. code-block:: python
#you can learn the usage of ins_tag_weight by the following code.
'''
import paddle
import numpy as np
paddle.enable_static()
data = paddle.static.data(name="input", shape=[-1, 32,32], dtype="float32")
label = paddle.static.data(name="label", shape=[-1], dtype="int")
fc_out = paddle.static.nn.fc(input=data, size=2)
predict = paddle.nn.functional.softmax(input=fc_out)
ins_tag_weight = paddle.static.data(name='ins_tag', shape=[-1,16], lod_level=0, dtype='int64')
ins_tag_weight = paddle.static.data(name='ins_tag', shape=[-1,16], lod_level=0, dtype='float64')
fc_out = paddle.static.nn.fc(x=data, size=2)
predict = paddle.nn.functional.softmax(x=fc_out)
result=paddle.static.auc(input=predict, label=label, ins_tag_weight=ins_tag_weight)
place = paddle.CPUPlace()
......@@ -207,10 +206,12 @@ def auc(input,
exe.run(paddle.static.default_startup_program())
x = np.random.rand(3,32,32).astype("float32")
y = np.array([1,0,1])
output= exe.run(feed={"input": x,"label": y},
z = np.array([1,0,1])
output= exe.run(feed={"input": x,"label": y, "ins_tag_weight":z},
fetch_list=[result[0]])
print(output)
#[array([0.5])]
'''
"""
helper = LayerHelper("auc", **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册