Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleFL
提交
d3b64d89
P
PaddleFL
项目概览
PaddlePaddle
/
PaddleFL
通知
35
Star
5
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
6
列表
看板
标记
里程碑
合并请求
4
Wiki
3
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleFL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
6
Issue
6
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
3
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d3b64d89
编写于
8月 28, 2020
作者:
J
jed
提交者:
GitHub
8月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #109 from kaih70/master
add ks statistic & fix prng memory bug
上级
24953638
6a76f2dc
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
111 addition
and
1 deletion
+111
-1
core/psi/prng.h
core/psi/prng.h
+1
-1
python/paddle_fl/mpc/metrics.py
python/paddle_fl/mpc/metrics.py
+110
-0
未找到文件。
core/psi/prng.h
浏览文件 @
d3b64d89
...
...
@@ -63,7 +63,7 @@ public:
private:
// buffer num for aes cipher
static
const
size_t
_s_buffer_size
=
0x10000
0
;
static
const
size_t
_s_buffer_size
=
0x10000
;
static
const
size_t
_s_byte_capacity
=
_s_buffer_size
*
sizeof
(
block
);
...
...
python/paddle_fl/mpc/metrics.py
0 → 100644
浏览文件 @
d3b64d89
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MPC Metrics
"""
from
paddle.fluid.metrics
import
MetricBase
import
numpy
as
np
import
scipy
__all__
=
[
'KSstatistic'
,
]
def
_is_numpy_
(
var
):
return
isinstance
(
var
,
(
np
.
ndarray
,
np
.
generic
))
class
KSstatistic
(
MetricBase
):
"""
The is for binary classification.
Refer to https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test#Kolmogorov%E2%80%93Smirnov_statistic
Please notice that the KS statistic is implemented with scipy.
The `KSstatistic` function creates 2 local variables, `data1`, `data2`
which is predictions of positive samples and negative samples respectively
that are used to compute the KS statistic.
Args:
name (str, optional): Metric name. For details, please refer to :ref:`api_guide_Name`. Default is None.
Examples:
.. code-block:: python
import paddle_fl.mpc
import numpy as np
# init the KSstatistic
ks = paddle_fl.mpc.metrics.KSstatistic('ks')
# suppose that batch_size is 128
batch_num = 100
batch_size = 128
for batch_id in range(batch_num):
class0_preds = np.random.random(size = (batch_size, 1))
class1_preds = 1 - class0_preds
preds = np.concatenate((class0_preds, class1_preds), axis=1)
labels = np.random.randint(2, size = (batch_size, 1))
ks.update(preds = preds, labels = labels)
# shall be some score closing to 0.1 as the preds are randomly assigned
print("ks statistic for iteration %d is %.2f" % (batch_id, ks.eval()))
"""
def
__init__
(
self
,
name
=
None
):
super
(
KSstatistic
,
self
).
__init__
(
name
=
name
)
self
.
_data1
=
[]
self
.
_data2
=
[]
def
update
(
self
,
preds
,
labels
):
"""
Update the auc curve with the given predictions and labels.
Args:
preds (numpy.array): an numpy array in the shape of
(batch_size, 2), preds[i][j] denotes the probability of
classifying the instance i into the class j.
labels (numpy.array): an numpy array in the shape of
(batch_size, 1), labels[i] is either o or 1, representing
the label of the instance i.
"""
if
not
_is_numpy_
(
labels
):
raise
ValueError
(
"The 'labels' must be a numpy ndarray."
)
if
not
_is_numpy_
(
preds
):
raise
ValueError
(
"The 'predictions' must be a numpy ndarray."
)
data1
=
[
preds
[
i
,
1
]
for
i
,
lbl
in
enumerate
(
labels
)
if
lbl
]
data2
=
[
preds
[
i
,
1
]
for
i
,
lbl
in
enumerate
(
labels
)
if
not
lbl
]
self
.
_data1
+=
data1
self
.
_data2
+=
data2
def
eval
(
self
):
"""
Return the area (a float score) under auc curve
Return:
float: the area under auc curve
"""
return
scipy
.
stats
.
ks_2samp
(
self
.
_data1
,
self
.
_data2
).
statistic
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录