Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleFL
提交
ccd096c7
P
PaddleFL
项目概览
PaddlePaddle
/
PaddleFL
通知
35
Star
5
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
6
列表
看板
标记
里程碑
合并请求
4
Wiki
3
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleFL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
6
Issue
6
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
3
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ccd096c7
编写于
9月 11, 2020
作者:
R
root
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'master' of
https://github.com/PaddlePaddle/PaddleFL
into smc612
上级
1aec36bc
398d77e8
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
819 addition
and
6 deletion
+819
-6
core/paddlefl_mpc/mpc_protocol/aby3_operators.h
core/paddlefl_mpc/mpc_protocol/aby3_operators.h
+39
-0
core/paddlefl_mpc/mpc_protocol/mpc_operators.h
core/paddlefl_mpc/mpc_protocol/mpc_operators.h
+10
-0
core/paddlefl_mpc/operators/CMakeLists.txt
core/paddlefl_mpc/operators/CMakeLists.txt
+2
-1
core/paddlefl_mpc/operators/metrics/precision_recall_op.cc
core/paddlefl_mpc/operators/metrics/precision_recall_op.cc
+170
-0
core/paddlefl_mpc/operators/metrics/precision_recall_op.h
core/paddlefl_mpc/operators/metrics/precision_recall_op.h
+68
-0
core/privc3/fixedpoint_tensor.h
core/privc3/fixedpoint_tensor.h
+15
-3
core/privc3/fixedpoint_tensor_imp.h
core/privc3/fixedpoint_tensor_imp.h
+169
-1
core/privc3/fixedpoint_tensor_test.cc
core/privc3/fixedpoint_tensor_test.cc
+85
-0
python/paddle_fl/mpc/layers/__init__.py
python/paddle_fl/mpc/layers/__init__.py
+4
-0
python/paddle_fl/mpc/layers/metric_op.py
python/paddle_fl/mpc/layers/metric_op.py
+135
-0
python/paddle_fl/mpc/metrics.py
python/paddle_fl/mpc/metrics.py
+1
-1
python/paddle_fl/mpc/tests/unittests/run_test_example.sh
python/paddle_fl/mpc/tests/unittests/run_test_example.sh
+1
-0
python/paddle_fl/mpc/tests/unittests/test_op_metric.py
python/paddle_fl/mpc/tests/unittests/test_op_metric.py
+120
-0
未找到文件。
core/paddlefl_mpc/mpc_protocol/aby3_operators.h
浏览文件 @
ccd096c7
...
@@ -338,6 +338,45 @@ public:
...
@@ -338,6 +338,45 @@ public:
x_
->
inverse_square_root
(
y_
);
x_
->
inverse_square_root
(
y_
);
}
}
// only support pred for 1 in binary classification for now
void
predicts_to_indices
(
const
Tensor
*
in
,
Tensor
*
out
,
float
threshold
=
0.5
)
override
{
auto
x_tuple
=
from_tensor
(
in
);
auto
x_
=
std
::
get
<
0
>
(
x_tuple
).
get
();
auto
y_tuple
=
from_tensor
(
out
);
auto
y_
=
std
::
get
<
0
>
(
y_tuple
).
get
();
FixedTensor
::
preds_to_indices
(
x_
,
y_
,
threshold
);
}
void
calc_tp_fp_fn
(
const
Tensor
*
indices
,
const
Tensor
*
labels
,
Tensor
*
out
)
override
{
auto
idx_tuple
=
from_tensor
(
indices
);
auto
idx
=
std
::
get
<
0
>
(
idx_tuple
).
get
();
auto
lbl_tuple
=
from_tensor
(
labels
);
auto
lbl
=
std
::
get
<
0
>
(
lbl_tuple
).
get
();
auto
out_tuple
=
from_tensor
(
out
);
auto
out_
=
std
::
get
<
0
>
(
out_tuple
).
get
();
FixedTensor
::
calc_tp_fp_fn
(
idx
,
lbl
,
out_
);
}
void
calc_precision_recall
(
const
Tensor
*
tp_fp_fn
,
Tensor
*
out
)
override
{
auto
in_tuple
=
from_tensor
(
tp_fp_fn
);
auto
in
=
std
::
get
<
0
>
(
in_tuple
).
get
();
PaddleTensor
out_
(
ContextHolder
::
device_ctx
(),
*
out
);
out_
.
scaling_factor
()
=
ABY3_SCALING_FACTOR
;
FixedTensor
::
calc_precision_recall
(
in
,
&
out_
);
}
private:
private:
template
<
typename
T
>
template
<
typename
T
>
std
::
tuple
<
std
::
tuple
<
...
...
core/paddlefl_mpc/mpc_protocol/mpc_operators.h
浏览文件 @
ccd096c7
...
@@ -83,6 +83,16 @@ public:
...
@@ -83,6 +83,16 @@ public:
virtual
void
max_pooling
(
const
Tensor
*
in
,
Tensor
*
out
,
Tensor
*
pos_info
)
{}
virtual
void
max_pooling
(
const
Tensor
*
in
,
Tensor
*
out
,
Tensor
*
pos_info
)
{}
virtual
void
inverse_square_root
(
const
Tensor
*
in
,
Tensor
*
out
)
=
0
;
virtual
void
inverse_square_root
(
const
Tensor
*
in
,
Tensor
*
out
)
=
0
;
virtual
void
predicts_to_indices
(
const
Tensor
*
in
,
Tensor
*
out
,
float
threshold
=
0.5
)
=
0
;
virtual
void
calc_tp_fp_fn
(
const
Tensor
*
indices
,
const
Tensor
*
labels
,
Tensor
*
out
)
=
0
;
virtual
void
calc_precision_recall
(
const
Tensor
*
tp_fp_fn
,
Tensor
*
out
)
=
0
;
};
};
}
// mpc
}
// mpc
...
...
core/paddlefl_mpc/operators/CMakeLists.txt
浏览文件 @
ccd096c7
aux_source_directory
(
. DIR_SRCS
)
aux_source_directory
(
. DIR_SRCS
)
aux_source_directory
(
./math MATH_SRCS
)
aux_source_directory
(
./math MATH_SRCS
)
add_library
(
mpc_ops_o OBJECT
${
DIR_SRCS
}
${
MATH_SRCS
}
)
aux_source_directory
(
./metrics METRICS_SRCS
)
add_library
(
mpc_ops_o OBJECT
${
DIR_SRCS
}
${
MATH_SRCS
}
${
METRICS_SRCS
}
)
add_dependencies
(
mpc_ops_o fluid_framework gloo
)
add_dependencies
(
mpc_ops_o fluid_framework gloo
)
add_library
(
mpc_ops STATIC $<TARGET_OBJECTS:mpc_ops_o>
)
add_library
(
mpc_ops STATIC $<TARGET_OBJECTS:mpc_ops_o>
)
...
...
core/paddlefl_mpc/operators/metrics/precision_recall_op.cc
0 → 100644
浏览文件 @
ccd096c7
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "precision_recall_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include <string>
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
class
MpcPrecisionRecallOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Predicts"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Predicts) should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Labels"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Labels) should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"BatchMetrics"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Output(BatchMetrics) should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"AccumMetrics"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Output(AccumMetrics) should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"AccumStatesInfo"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Output(AccumStatesInfo) should not be null."
));
int64_t
cls_num
=
static_cast
<
int64_t
>
(
ctx
->
Attrs
().
Get
<
int
>
(
"class_number"
));
PADDLE_ENFORCE_EQ
(
cls_num
,
1
,
platform
::
errors
::
InvalidArgument
(
"Only support predicts/labels for 1"
"in binary classification for now."
));
auto
preds_dims
=
ctx
->
GetInputDim
(
"Predicts"
);
auto
labels_dims
=
ctx
->
GetInputDim
(
"Labels"
);
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
preds_dims
,
labels_dims
,
platform
::
errors
::
InvalidArgument
(
"The dimension of Input(Predicts) and "
"Input(Labels) should be the same."
"But received (%d) != (%d)"
,
preds_dims
,
labels_dims
));
PADDLE_ENFORCE_EQ
(
labels_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"Only support predicts/labels for 1"
"in binary classification for now."
"The dimension of Input(Labels) should be equal to 2 "
"(1 for shares). But received (%d)"
,
labels_dims
.
size
()));
}
if
(
ctx
->
HasInput
(
"StatesInfo"
))
{
auto
states_dims
=
ctx
->
GetInputDim
(
"StatesInfo"
);
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
states_dims
,
framework
::
make_ddim
({
2
,
3
}),
platform
::
errors
::
InvalidArgument
(
"The shape of Input(StatesInfo) should be [2, 3]."
));
}
}
// Layouts of BatchMetrics and AccumMetrics both are:
// [
// precision, recall, F1 score,
// ]
ctx
->
SetOutputDim
(
"BatchMetrics"
,
{
3
});
ctx
->
SetOutputDim
(
"AccumMetrics"
,
{
3
});
// Shape of AccumStatesInfo is [3]
// The layout of each row is:
// [ TP, FP, FN ]
ctx
->
SetOutputDim
(
"AccumStatesInfo"
,
{
2
,
3
});
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"Predicts"
),
ctx
.
device_context
());
}
};
class
MpcPrecisionRecallOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Predicts"
,
"(Tensor, default Tensor<int64_t>) A 1-D tensor with shape N, "
"where N is the batch size. Each element contains the "
"corresponding predicts of an instance which computed by the "
"previous sigmoid operator."
);
AddInput
(
"Labels"
,
"(Tensor, default Tensor<int>) A 1-D tensor with shape N, "
"where N is the batch size. Each element is a label and the "
"value should be in [0, 1]."
);
AddInput
(
"StatesInfo"
,
"(Tensor, default Tensor<int>) A 1-D tensor with shape 3. "
"This input is optional. If provided, current state will be "
"accumulated to this state and the accumulation state will be "
"the output state."
)
.
AsDispensable
();
AddOutput
(
"BatchMetrics"
,
"(Tensor, default Tensor<int64_t>) A 1-D tensor with shape {3}. "
"This output tensor contains metrics for current batch data. "
"The layout is [precision, recall, f1 score]."
);
AddOutput
(
"AccumMetrics"
,
"(Tensor, default Tensor<int64_t>) A 1-D tensor with shape {3}. "
"This output tensor contains metrics for accumulated data. "
"The layout is [precision, recall, f1 score]."
);
AddOutput
(
"AccumStatesInfo"
,
"(Tensor, default Tensor<int64_t>) A 1-D tensor with shape 3. "
"This output tensor contains "
"accumulated state variables used to compute metrics. The layout "
"for each class is [true positives, false positives, "
"false negatives]."
);
AddAttr
<
int
>
(
"class_number"
,
"(int) Number of classes to be evaluated."
);
AddAttr
<
float
>
(
"threshold"
,
"(threshold) Threshold of true predict."
);
AddComment
(
R"DOC(
Precision Recall Operator.
When given Input(Indices) and Input(Labels), this operator can be used
to compute various metrics including:
1. precision
2. recall
3. f1 score
To compute the above metrics, we need to do statistics for true positives,
false positives and false negatives.
We define state as a 1-D tensor with shape [3]. Each element of a
state contains statistic variables for corresponding class. Layout of each row
is: TP(true positives), FP(false positives), FN(false negatives).
This operator also supports metrics computing for cross-batch situation. To
achieve this, Input(StatesInfo) should be provided. State of current batch
data will be accumulated to Input(StatesInfo) and Output(AccumStatesInfo)
is the accumulation state.
Output(BatchMetrics) is metrics of current batch data while
Output(AccumStatesInfo) is metrics of accumulation data.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
mpc_precision_recall
,
ops
::
MpcPrecisionRecallOp
,
ops
::
MpcPrecisionRecallOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
mpc_precision_recall
,
ops
::
MpcPrecisionRecallKernel
<
paddle
::
platform
::
CPUPlace
,
int64_t
>
);
core/paddlefl_mpc/operators/metrics/precision_recall_op.h
0 → 100644
浏览文件 @
ccd096c7
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "../mpc_op.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
DeviceContext
,
typename
T
>
class
MpcPrecisionRecallKernel
:
public
MpcOpKernel
<
T
>
{
public:
void
ComputeImpl
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
Tensor
*
preds
=
context
.
Input
<
Tensor
>
(
"Predicts"
);
const
Tensor
*
lbls
=
context
.
Input
<
Tensor
>
(
"Labels"
);
const
Tensor
*
stats
=
context
.
Input
<
Tensor
>
(
"StatesInfo"
);
Tensor
*
batch_metrics
=
context
.
Output
<
Tensor
>
(
"BatchMetrics"
);
Tensor
*
accum_metrics
=
context
.
Output
<
Tensor
>
(
"AccumMetrics"
);
Tensor
*
accum_stats
=
context
.
Output
<
Tensor
>
(
"AccumStatesInfo"
);
float
threshold
=
context
.
Attr
<
float
>
(
"threshold"
);
Tensor
idx
;
idx
.
mutable_data
<
T
>
(
preds
->
dims
(),
context
.
GetPlace
(),
0
);
Tensor
batch_stats
;
batch_stats
.
mutable_data
<
T
>
(
stats
->
dims
(),
context
.
GetPlace
(),
0
);
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
predicts_to_indices
(
preds
,
&
idx
,
threshold
);
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
calc_tp_fp_fn
(
&
idx
,
lbls
,
&
batch_stats
);
batch_metrics
->
mutable_data
<
T
>
(
framework
::
make_ddim
({
3
}),
context
.
GetPlace
(),
0
);
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
calc_precision_recall
(
&
batch_stats
,
batch_metrics
);
if
(
stats
)
{
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
add
(
&
batch_stats
,
stats
,
accum_stats
);
accum_metrics
->
mutable_data
<
T
>
(
framework
::
make_ddim
({
3
}),
context
.
GetPlace
(),
0
);
mpc
::
MpcInstance
::
mpc_instance
()
->
mpc_protocol
()
->
mpc_operators
()
->
calc_precision_recall
(
accum_stats
,
accum_metrics
);
}
}
};
}
// namespace operators
}
// namespace paddle
core/privc3/fixedpoint_tensor.h
浏览文件 @
ccd096c7
...
@@ -16,12 +16,10 @@
...
@@ -16,12 +16,10 @@
#include <vector>
#include <vector>
#include "boolean_tensor.h"
#include "aby3_context.h"
#include "aby3_context.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "paddle_tensor.h"
#include "paddle_tensor.h"
#include "boolean_tensor.h"
#include "boolean_tensor.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
namespace
aby3
{
namespace
aby3
{
...
@@ -193,6 +191,20 @@ public:
...
@@ -193,6 +191,20 @@ public:
void
max_pooling
(
FixedPointTensor
*
ret
,
void
max_pooling
(
FixedPointTensor
*
ret
,
BooleanTensor
<
T
>*
pos
=
nullptr
)
const
;
BooleanTensor
<
T
>*
pos
=
nullptr
)
const
;
// only support pred for 1 in binary classification for now
static
void
preds_to_indices
(
const
FixedPointTensor
*
preds
,
FixedPointTensor
*
indices
,
float
threshold
=
0.5
);
static
void
calc_tp_fp_fn
(
const
FixedPointTensor
*
indices
,
const
FixedPointTensor
*
labels
,
FixedPointTensor
*
tp_fp_fn
);
// clac precision_recall f1_score
// result is a plaintext fixed-point tensor, shape is [3]
static
void
calc_precision_recall
(
const
FixedPointTensor
*
tp_fp_fn
,
TensorAdapter
<
T
>*
ret
);
static
void
truncate
(
const
FixedPointTensor
*
op
,
FixedPointTensor
*
ret
,
static
void
truncate
(
const
FixedPointTensor
*
op
,
FixedPointTensor
*
ret
,
size_t
scaling_factor
);
size_t
scaling_factor
);
...
@@ -217,7 +229,7 @@ private:
...
@@ -217,7 +229,7 @@ private:
size_t
scaling_factor
);
size_t
scaling_factor
);
// reduce last dim
// reduce last dim
static
void
reduce
(
FixedPointTensor
<
T
,
N
>*
input
,
static
void
reduce
(
const
FixedPointTensor
<
T
,
N
>*
input
,
FixedPointTensor
<
T
,
N
>*
ret
);
FixedPointTensor
<
T
,
N
>*
ret
);
static
size_t
party
()
{
static
size_t
party
()
{
...
...
core/privc3/fixedpoint_tensor_imp.h
浏览文件 @
ccd096c7
...
@@ -847,7 +847,7 @@ void FixedPointTensor<T, N>::long_div(const FixedPointTensor<T, N>* rhs,
...
@@ -847,7 +847,7 @@ void FixedPointTensor<T, N>::long_div(const FixedPointTensor<T, N>* rhs,
// reduce last dim
// reduce last dim
template
<
typename
T
,
size_t
N
>
template
<
typename
T
,
size_t
N
>
void
FixedPointTensor
<
T
,
N
>::
reduce
(
FixedPointTensor
<
T
,
N
>*
input
,
void
FixedPointTensor
<
T
,
N
>::
reduce
(
const
FixedPointTensor
<
T
,
N
>*
input
,
FixedPointTensor
<
T
,
N
>*
ret
)
{
FixedPointTensor
<
T
,
N
>*
ret
)
{
//enfoce shape: input->shape[0 ... (n-2)] == ret shape
//enfoce shape: input->shape[0 ... (n-2)] == ret shape
auto
&
shape
=
input
->
shape
();
auto
&
shape
=
input
->
shape
();
...
@@ -1293,4 +1293,172 @@ void FixedPointTensor<T, N>::max_pooling(FixedPointTensor* ret,
...
@@ -1293,4 +1293,172 @@ void FixedPointTensor<T, N>::max_pooling(FixedPointTensor* ret,
}
}
template
<
typename
T
,
size_t
N
>
void
FixedPointTensor
<
T
,
N
>::
preds_to_indices
(
const
FixedPointTensor
*
preds
,
FixedPointTensor
*
indices
,
float
threshold
)
{
// 3 for allocating temp tensor
std
::
vector
<
std
::
shared_ptr
<
TensorAdapter
<
T
>>>
temp
;
for
(
size_t
i
=
0
;
i
<
3
;
++
i
)
{
temp
.
emplace_back
(
tensor_factory
()
->
template
create
<
T
>());
}
auto
shape_
=
preds
->
shape
();
// plaintext tensor for threshold
temp
[
0
]
->
reshape
(
shape_
);
temp
[
0
]
->
scaling_factor
()
=
N
;
assign_to_tensor
(
temp
[
0
].
get
(),
T
(
threshold
*
(
T
(
1
)
<<
N
)));
temp
[
1
]
->
reshape
(
shape_
);
temp
[
2
]
->
reshape
(
shape_
);
BooleanTensor
<
T
>
cmp_res
(
temp
[
1
].
get
(),
temp
[
2
].
get
());
preds
->
gt
(
temp
[
0
].
get
(),
&
cmp_res
);
cmp_res
.
lshift
(
N
,
&
cmp_res
);
cmp_res
.
b2a
(
indices
);
}
template
<
typename
T
,
size_t
N
>
void
FixedPointTensor
<
T
,
N
>::
calc_tp_fp_fn
(
const
FixedPointTensor
*
indices
,
const
FixedPointTensor
*
labels
,
FixedPointTensor
*
tp_fp_fn
)
{
PADDLE_ENFORCE_EQ
(
indices
->
shape
().
size
(),
1
,
"multi-classification not support yet"
);
PADDLE_ENFORCE_EQ
(
tp_fp_fn
->
shape
().
size
(),
1
,
"multi-classification not support yet"
);
PADDLE_ENFORCE_EQ
(
tp_fp_fn
->
shape
()[
0
],
3
,
"store tp fp fn for binary-classification only"
);
// 4 for allocating temp tensor
std
::
vector
<
std
::
shared_ptr
<
TensorAdapter
<
T
>>>
temp
;
for
(
size_t
i
=
0
;
i
<
4
;
++
i
)
{
temp
.
emplace_back
(
tensor_factory
()
->
template
create
<
T
>());
}
auto
shape_
=
indices
->
shape
();
std
::
vector
<
size_t
>
shape_one
=
{
1
};
std
::
vector
<
size_t
>
shape_3
=
{
3
};
temp
[
0
]
->
reshape
(
shape_
);
temp
[
1
]
->
reshape
(
shape_
);
FixedPointTensor
true_positive
(
temp
[
0
].
get
(),
temp
[
1
].
get
());
indices
->
mul
(
labels
,
&
true_positive
);
temp
[
2
]
->
reshape
(
shape_one
);
temp
[
3
]
->
reshape
(
shape_one
);
FixedPointTensor
scalar
(
temp
[
2
].
get
(),
temp
[
3
].
get
());
// tp
reduce
(
&
true_positive
,
&
scalar
);
const
T
&
share0
=
scalar
.
share
(
0
)
->
data
()[
0
];
const
T
&
share1
=
scalar
.
share
(
1
)
->
data
()[
0
];
T
*
ret_data0
=
tp_fp_fn
->
mutable_share
(
0
)
->
data
();
T
*
ret_data1
=
tp_fp_fn
->
mutable_share
(
1
)
->
data
();
// assgin tp
ret_data0
[
0
]
=
share0
;
ret_data1
[
0
]
=
share1
;
// tp + fp
reduce
(
indices
,
&
scalar
);
// direcrt aby3 sub
ret_data0
[
1
]
=
share0
-
ret_data0
[
0
];
ret_data1
[
1
]
=
share1
-
ret_data1
[
0
];
// tp + fn
reduce
(
labels
,
&
scalar
);
ret_data0
[
2
]
=
share0
-
ret_data0
[
0
];
ret_data1
[
2
]
=
share1
-
ret_data1
[
0
];
}
template
<
typename
T
,
size_t
N
>
void
FixedPointTensor
<
T
,
N
>::
calc_precision_recall
(
const
FixedPointTensor
*
tp_fp_fn
,
TensorAdapter
<
T
>*
ret
)
{
PADDLE_ENFORCE_EQ
(
tp_fp_fn
->
shape
().
size
(),
1
,
"multi-classification not support yet"
);
PADDLE_ENFORCE_EQ
(
tp_fp_fn
->
shape
()[
0
],
3
,
"store tp fp fn for binary-classification only"
);
PADDLE_ENFORCE_EQ
(
ret
->
shape
().
size
(),
1
,
"multi-classification not support yet"
);
PADDLE_ENFORCE_EQ
(
ret
->
shape
()[
0
],
3
,
"store precision recall f1-score"
"for binary-classification only"
);
// 5 for allocating temp tensor
std
::
vector
<
std
::
shared_ptr
<
TensorAdapter
<
T
>>>
temp
;
for
(
size_t
i
=
0
;
i
<
5
;
++
i
)
{
temp
.
emplace_back
(
tensor_factory
()
->
template
create
<
T
>());
}
std
::
vector
<
size_t
>
shape_
=
{
3
};
std
::
vector
<
size_t
>
shape_one
=
{
1
};
temp
[
0
]
->
reshape
(
shape_one
);
temp
[
1
]
->
reshape
(
shape_one
);
FixedPointTensor
scalar
(
temp
[
0
].
get
(),
temp
[
1
].
get
());
temp
[
2
]
->
reshape
(
shape_one
);
temp
[
3
]
->
reshape
(
shape_one
);
FixedPointTensor
scalar2
(
temp
[
2
].
get
(),
temp
[
3
].
get
());
auto
get
=
[
&
tp_fp_fn
](
size_t
idx
,
FixedPointTensor
*
dest
)
{
dest
->
mutable_share
(
0
)
->
data
()[
0
]
=
tp_fp_fn
->
share
(
0
)
->
data
()[
idx
];
dest
->
mutable_share
(
1
)
->
data
()[
0
]
=
tp_fp_fn
->
share
(
1
)
->
data
()[
idx
];
};
get
(
0
,
&
scalar
);
get
(
1
,
&
scalar2
);
// tp + fp
scalar
.
add
(
&
scalar2
,
&
scalar2
);
scalar
.
long_div
(
&
scalar2
,
&
scalar2
);
temp
[
4
]
->
reshape
(
shape_one
);
scalar2
.
reveal
(
temp
[
4
].
get
());
ret
->
scaling_factor
()
=
N
;
ret
->
data
()[
0
]
=
temp
[
4
]
->
data
()[
0
];
get
(
2
,
&
scalar2
);
// tp + fn
scalar
.
add
(
&
scalar2
,
&
scalar2
);
scalar
.
long_div
(
&
scalar2
,
&
scalar2
);
scalar2
.
reveal
(
temp
[
4
].
get
());
ret
->
data
()[
1
]
=
temp
[
4
]
->
data
()[
0
];
float
precision
=
1.0
*
ret
->
data
()[
0
]
/
(
T
(
1
)
<<
N
);
float
recall
=
1.0
*
ret
->
data
()[
1
]
/
(
T
(
1
)
<<
N
);
float
f1_score
=
0.0
;
if
(
precision
+
recall
>
0
)
{
f1_score
=
2
*
precision
*
recall
/
(
precision
+
recall
);
}
ret
->
data
()[
2
]
=
T
(
f1_score
*
(
T
(
1
)
<<
N
));
}
}
// namespace aby3
}
// namespace aby3
core/privc3/fixedpoint_tensor_test.cc
浏览文件 @
ccd096c7
...
@@ -898,6 +898,40 @@ void test_fixedt_matmul_fixed(size_t p,
...
@@ -898,6 +898,40 @@ void test_fixedt_matmul_fixed(size_t p,
result
->
reveal
(
out
);
result
->
reveal
(
out
);
}
}
void
test_fixedt_precision_recall_fixed
(
size_t
p
,
double
threshold
,
std
::
vector
<
std
::
shared_ptr
<
TensorAdapter
<
int64_t
>>>
in
,
TensorAdapter
<
int64_t
>*
out
)
{
std
::
vector
<
std
::
shared_ptr
<
TensorAdapter
<
int64_t
>>>
temp
;
// preds
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
temp
.
emplace_back
(
gen
(
in
[
0
]
->
shape
()));
}
// labels
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
temp
.
emplace_back
(
gen
(
in
[
1
]
->
shape
()));
}
// indices
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
temp
.
emplace_back
(
gen
(
in
[
0
]
->
shape
()));
}
std
::
vector
<
size_t
>
shape_
=
{
3
};
// tp fp fn
for
(
int
i
=
0
;
i
<
2
;
i
++
)
{
temp
.
emplace_back
(
gen
(
shape_
));
}
test_fixedt_gen_shares
(
p
,
in
,
temp
);
Fix64N16
*
preds
=
new
Fix64N16
(
temp
[
0
].
get
(),
temp
[
1
].
get
());
Fix64N16
*
labels
=
new
Fix64N16
(
temp
[
2
].
get
(),
temp
[
3
].
get
());
Fix64N16
*
indices
=
new
Fix64N16
(
temp
[
4
].
get
(),
temp
[
5
].
get
());
Fix64N16
*
tpfpfn
=
new
Fix64N16
(
temp
[
6
].
get
(),
temp
[
7
].
get
());
Fix64N16
::
preds_to_indices
(
preds
,
indices
,
threshold
);
Fix64N16
::
calc_tp_fp_fn
(
indices
,
labels
,
tpfpfn
);
Fix64N16
::
calc_precision_recall
(
tpfpfn
,
out
);
}
TEST_F
(
FixedTensorTest
,
matmulfixed
)
{
TEST_F
(
FixedTensorTest
,
matmulfixed
)
{
std
::
vector
<
size_t
>
shape
=
{
1
,
3
};
std
::
vector
<
size_t
>
shape
=
{
1
,
3
};
...
@@ -3559,4 +3593,55 @@ TEST_F(FixedTensorTest, truncate3_msb_correct) {
...
@@ -3559,4 +3593,55 @@ TEST_F(FixedTensorTest, truncate3_msb_correct) {
}
}
#endif
#endif
TEST_F
(
FixedTensorTest
,
precision_recall
)
{
std
::
vector
<
size_t
>
shape
=
{
6
};
std
::
vector
<
size_t
>
shape_o
=
{
3
};
std
::
vector
<
double
>
in0_val
=
{
0.0
,
0.2
,
0.4
,
0.6
,
0.8
,
1.0
};
std
::
vector
<
double
>
in1_val
=
{
0
,
1
,
0
,
1
,
0
,
1
};
std
::
vector
<
double
>
res_val
=
{
0.5
,
1.0
/
3
,
0.4
};
double
threshold
=
0.7
;
std
::
vector
<
std
::
shared_ptr
<
TensorAdapter
<
int64_t
>>>
in
=
{
gen
(
shape
),
gen
(
shape
)};
test_fixedt_gen_paddle_tensor
<
int64_t
,
16
>
(
in0_val
,
shape
,
_cpu_ctx
).
copy
(
in
[
0
].
get
());
test_fixedt_gen_paddle_tensor
<
int64_t
,
16
>
(
in1_val
,
shape
,
_cpu_ctx
).
copy
(
in
[
1
].
get
());
auto
out0
=
_s_tensor_factory
->
create
<
int64_t
>
(
shape_o
);
auto
out1
=
_s_tensor_factory
->
create
<
int64_t
>
(
shape_o
);
auto
out2
=
_s_tensor_factory
->
create
<
int64_t
>
(
shape_o
);
PaddleTensor
<
int64_t
>
result
=
test_fixedt_gen_paddle_tensor
<
int64_t
,
16
>
(
res_val
,
shape_o
,
_cpu_ctx
);
_t
[
0
]
=
std
::
thread
([
this
,
in
,
out0
,
threshold
]()
mutable
{
g_ctx_holder
::
template
run_with_context
(
_exec_ctx
.
get
(),
_mpc_ctx
[
0
],
[
&
]()
{
test_fixedt_precision_recall_fixed
(
0
,
threshold
,
in
,
out0
.
get
());
});
});
_t
[
1
]
=
std
::
thread
([
this
,
in
,
out1
,
threshold
]()
mutable
{
g_ctx_holder
::
template
run_with_context
(
_exec_ctx
.
get
(),
_mpc_ctx
[
1
],
[
&
]()
{
test_fixedt_precision_recall_fixed
(
1
,
threshold
,
in
,
out1
.
get
());
});
}
);
_t
[
2
]
=
std
::
thread
([
this
,
in
,
out2
,
threshold
]()
mutable
{
g_ctx_holder
::
template
run_with_context
(
_exec_ctx
.
get
(),
_mpc_ctx
[
2
],
[
&
]()
{
test_fixedt_precision_recall_fixed
(
2
,
threshold
,
in
,
out2
.
get
());
});
}
);
_t
[
0
].
join
();
_t
[
1
].
join
();
_t
[
2
].
join
();
EXPECT_TRUE
(
test_fixedt_check_tensor_eq
(
out0
.
get
(),
out1
.
get
()));
EXPECT_TRUE
(
test_fixedt_check_tensor_eq
(
out1
.
get
(),
out2
.
get
()));
EXPECT_TRUE
(
test_fixedt_check_tensor_eq
(
out0
.
get
(),
&
result
));
}
}
// namespace aby3
}
// namespace aby3
python/paddle_fl/mpc/layers/__init__.py
浏览文件 @
ccd096c7
...
@@ -18,6 +18,7 @@ mpc layers:
...
@@ -18,6 +18,7 @@ mpc layers:
matrix: 'mul'
matrix: 'mul'
ml: 'fc', 'relu', 'softmax'(todo)
ml: 'fc', 'relu', 'softmax'(todo)
compare:'greater_than', 'greater_equal', 'less_than', 'less_equal', 'equal', 'not_equal'
compare:'greater_than', 'greater_equal', 'less_than', 'less_equal', 'equal', 'not_equal'
metric_op:'precision_recall'
"""
"""
from
.
import
basic
from
.
import
basic
...
@@ -34,6 +35,8 @@ from . import conv
...
@@ -34,6 +35,8 @@ from . import conv
from
.conv
import
conv2d
from
.conv
import
conv2d
from
.
import
rnn
from
.
import
rnn
from
.rnn
import
*
from
.rnn
import
*
from
.
import
metric_op
from
.metric_op
import
*
__all__
=
[]
__all__
=
[]
__all__
+=
basic
.
__all__
__all__
+=
basic
.
__all__
...
@@ -42,3 +45,4 @@ __all__ += matrix.__all__
...
@@ -42,3 +45,4 @@ __all__ += matrix.__all__
__all__
+=
ml
.
__all__
__all__
+=
ml
.
__all__
__all__
+=
compare
.
__all__
__all__
+=
compare
.
__all__
__all__
+=
conv
.
__all__
__all__
+=
conv
.
__all__
__all__
+=
metric_op
.
__all__
python/paddle_fl/mpc/layers/metric_op.py
0 → 100644
浏览文件 @
ccd096c7
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
mpc metric op layers.
"""
from
paddle.fluid.data_feeder
import
check_type
,
check_dtype
from
paddle.fluid.initializer
import
Constant
from
..framework
import
check_mpc_variable_and_dtype
from
..mpc_layer_helper
import
MpcLayerHelper
__all__
=
[
'precision_recall'
]
def
precision_recall
(
input
,
label
,
threshold
=
0.5
):
"""
Precision (also called positive predictive value) is the fraction of
relevant instances among the retrieved instances.
Recall (also known as sensitivity) is the fraction of
relevant instances that have been retrieved over the
total amount of relevant instances
F1-score is a measure of a test's accuracy.
It is calculated from the precision and recall of the test.
Refer to:
https://en.wikipedia.org/wiki/Precision_and_recall
https://en.wikipedia.org/wiki/F1_score
Noted that this class manages the metrics only for binary classification task.
Noted that in both precision and recall, define 0/0 equals to 0.
Args:
input (Variable): ciphtext predicts for 1 in binary classification.
label (Variable): labels in ciphertext.
threshold (float): predict threshold.
Returns:
batch_out (Variable): plaintext of batch metrics [precision, recall, f1-score]
Note that values in batch_out are fixed-point number.
To get float type values, div fetched batch_out by
3 * mpc_data_utils.mpc_one_share (which equals to 2**16).
acc_out (Variable): plaintext of accumulated metrics [precision, recall, f1-score]
To get float type values, div fetched acc_out by
3 * mpc_data_utils.mpc_one_share (which equals to 2**16).
Examples:
.. code-block:: python
import sys
import numpy as np
import paddle.fluid as fluid
import paddle_fl.mpc as pfl_mpc
import mpc_data_utils as mdu
role = int(sys.argv[1])
redis_server = "127.0.0.1"
redis_port = 9937
loop = 5
np.random.seed(0)
input_size = [100]
threshold = 0.6
preds, labels = [], []
preds_cipher, labels_cipher = [], []
#simulating mpc share
share = lambda x: np.array([x * mdu.mpc_one_share] * 2).astype('int64').reshape([2] + input_size)
for _ in range(loop):
preds.append(np.random.random(input_size))
labels.append(np.rint(np.random.random(input_size)))
preds_cipher.append(share(preds[-1]))
labels_cipher.append(share(labels[-1]))
pfl_mpc.init("aby3", role, "localhost", redis_server, redis_port)
x = pfl_mpc.data(name='x', shape=input_size, dtype='int64')
y = pfl_mpc.data(name='y', shape=input_size, dtype='int64')
out0, out1 = pfl_mpc.layers.precision_recall(input=x, label=y, threshold=threshold)
exe = fluid.Executor(place=fluid.CPUPlace())
exe.run(fluid.default_startup_program())
for i in range(loop):
batch_res, acc_res = exe.run(feed={'x': preds_cipher[i], 'y': labels_cipher[i]},
fetch_list=[out0, out1])
fixed_point_one = 3.0 * mdu.mpc_one_share
# result could be varified by calcuatling metrics with plaintext preds, labels
print(batch_res / fixed_point_one , acc_res / fixed_point_one)
"""
helper
=
MpcLayerHelper
(
"precision_recall"
,
**
locals
())
dtype
=
helper
.
input_dtype
()
check_dtype
(
dtype
,
'input'
,
[
'int64'
],
'precision_recall'
)
check_dtype
(
dtype
,
'label'
,
[
'int64'
],
'precision_recall'
)
batch_out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
=
input
.
dtype
)
acc_out
=
helper
.
create_mpc_variable_for_type_inference
(
dtype
=
input
.
dtype
)
stat
=
helper
.
create_global_mpc_variable
(
persistable
=
True
,
dtype
=
'int64'
,
shape
=
[
3
],
)
helper
.
set_variable_initializer
(
stat
,
Constant
(
value
=
0
))
op_type
=
'precision_recall'
helper
.
append_op
(
type
=
'mpc_'
+
op_type
,
inputs
=
{
"Predicts"
:
input
,
"Labels"
:
label
,
"StatesInfo"
:
stat
,
},
outputs
=
{
"BatchMetrics"
:
batch_out
,
"AccumMetrics"
:
acc_out
,
"AccumStatesInfo"
:
stat
,
},
attrs
=
{
"threshold"
:
threshold
,
"class_number"
:
1
,
})
return
batch_out
,
acc_out
python/paddle_fl/mpc/metrics.py
浏览文件 @
ccd096c7
...
@@ -34,7 +34,7 @@ def _is_numpy_(var):
...
@@ -34,7 +34,7 @@ def _is_numpy_(var):
class
KSstatistic
(
MetricBase
):
class
KSstatistic
(
MetricBase
):
"""
"""
The is for binary classification.
The
KSstatistic
is for binary classification.
Refer to https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test#Kolmogorov%E2%80%93Smirnov_statistic
Refer to https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test#Kolmogorov%E2%80%93Smirnov_statistic
Please notice that the KS statistic is implemented with scipy.
Please notice that the KS statistic is implemented with scipy.
...
...
python/paddle_fl/mpc/tests/unittests/run_test_example.sh
浏览文件 @
ccd096c7
...
@@ -25,6 +25,7 @@ TEST_MODULES=("test_datautils_aby3"
...
@@ -25,6 +25,7 @@ TEST_MODULES=("test_datautils_aby3"
"test_op_batch_norm"
"test_op_batch_norm"
"test_op_conv"
"test_op_conv"
"test_op_pool"
"test_op_pool"
"test_op_metric"
)
)
# run unittest
# run unittest
...
...
python/paddle_fl/mpc/tests/unittests/test_op_metric.py
0 → 100644
浏览文件 @
ccd096c7
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module test metric op.
"""
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle_fl.mpc
as
pfl_mpc
import
test_op_base
def
precision_recall_naive
(
input
,
label
,
threshold
=
0.5
,
stat
=
None
):
pred
=
input
-
(
threshold
-
0.5
)
pred
=
np
.
maximum
(
0
,
pred
)
pred
=
np
.
minimum
(
1
,
pred
)
idx
=
np
.
rint
(
pred
)
tp
=
np
.
sum
(
idx
*
label
)
fp
=
np
.
sum
(
idx
)
-
tp
fn
=
np
.
sum
(
label
)
-
tp
def
calc_precision
(
tp
,
fp
):
return
tp
/
(
tp
+
fp
)
if
tp
+
fp
>
0
else
0.0
def
calc_recall
(
tp
,
fn
):
return
tp
/
(
tp
+
fn
)
if
tp
+
fn
>
0
else
0.0
def
calc_f1
(
precision
,
recall
):
return
2
*
precision
*
recall
/
(
precision
+
recall
)
if
precision
+
recall
>
0
else
0.0
p_batch
,
r_batch
=
calc_precision
(
tp
,
fp
),
calc_recall
(
tp
,
fn
)
f_batch
=
calc_f1
(
p_batch
,
r_batch
)
p_acc
,
r_acc
,
f_acc
=
p_batch
,
r_batch
,
f_batch
if
stat
:
tp
+=
stat
[
0
]
fp
+=
stat
[
1
]
fn
+=
stat
[
2
]
p_acc
,
r_acc
=
calc_precision
(
tp
,
fp
),
calc_recall
(
tp
,
fn
)
f_acc
=
calc_f1
(
p_acc
,
r_acc
)
new_stat
=
[
tp
,
fp
,
fn
]
return
np
.
array
([
p_batch
,
r_batch
,
f_batch
,
p_acc
,
r_acc
,
f_acc
]),
new_stat
class
TestOpPrecisionRecall
(
test_op_base
.
TestOpBase
):
def
precision_recall
(
self
,
**
kwargs
):
"""
precision_recall op ut
:param kwargs:
:return:
"""
role
=
kwargs
[
'role'
]
preds
=
kwargs
[
'preds'
]
labels
=
kwargs
[
'labels'
]
loop
=
kwargs
[
'loop'
]
pfl_mpc
.
init
(
"aby3"
,
role
,
"localhost"
,
self
.
server
,
int
(
self
.
port
))
x
=
pfl_mpc
.
data
(
name
=
'x'
,
shape
=
self
.
input_size
,
dtype
=
'int64'
)
y
=
pfl_mpc
.
data
(
name
=
'y'
,
shape
=
self
.
input_size
,
dtype
=
'int64'
)
out0
,
out1
=
pfl_mpc
.
layers
.
precision_recall
(
input
=
x
,
label
=
y
,
threshold
=
self
.
threshold
)
exe
=
fluid
.
Executor
(
place
=
fluid
.
CPUPlace
())
exe
.
run
(
fluid
.
default_startup_program
())
for
i
in
range
(
loop
):
batch_res
,
acc_res
=
exe
.
run
(
feed
=
{
'x'
:
preds
[
i
],
'y'
:
labels
[
i
]},
fetch_list
=
[
out0
,
out1
])
self
.
assertTrue
(
np
.
allclose
(
batch_res
*
(
2
**
-
16
),
self
.
exp_res
[
0
][:
3
],
atol
=
1e-4
))
self
.
assertTrue
(
np
.
allclose
(
acc_res
*
(
2
**
-
16
),
self
.
exp_res
[
0
][
3
:],
atol
=
1e-4
))
def
n_batch_test
(
self
,
n
):
self
.
input_size
=
[
100
]
self
.
threshold
=
np
.
random
.
random
()
preds
,
labels
=
[],
[]
self
.
exp_res
=
(
0
,
[
0
]
*
3
)
share
=
lambda
x
:
np
.
array
([
x
*
65536
/
3
]
*
2
).
astype
(
'int64'
).
reshape
(
[
2
]
+
self
.
input_size
)
for
_
in
range
(
n
):
preds
.
append
(
np
.
random
.
random
(
self
.
input_size
))
labels
.
append
(
np
.
rint
(
np
.
random
.
random
(
self
.
input_size
)))
self
.
exp_res
=
precision_recall_naive
(
preds
[
-
1
],
labels
[
-
1
],
self
.
threshold
,
self
.
exp_res
[
-
1
])
preds
[
-
1
]
=
share
(
preds
[
-
1
])
labels
[
-
1
]
=
share
(
labels
[
-
1
])
ret
=
self
.
multi_party_run
(
target
=
self
.
precision_recall
,
preds
=
preds
,
labels
=
labels
,
loop
=
n
)
self
.
assertEqual
(
ret
[
0
],
True
)
def
test_1
(
self
):
self
.
n_batch_test
(
1
)
def
test_2
(
self
):
self
.
n_batch_test
(
2
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录