Commit bee67d95 authored by malin10

add metrics: AUC, PrecisionRecall, RecallK and PosNegRatio, with unit tests for each

Parent 469afd42
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .precision import Precision
from .recall_k import RecallK
from .pairwise_pn import PosNegRatio
from . import binary_class

__all__ = ['RecallK', 'PosNegRatio'] + binary_class.__all__
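
# The classes exported here are consumed as, for example (mirroring the unit
# tests included in this commit):
#
#     from paddlerec.core.metrics import AUC, PrecisionRecall, RecallK, PosNegRatio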
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .auc import AUC
from .precision_recall import PrecisionRecall
__all__ = ['PrecisionRecall', 'AUC']
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle.fluid as fluid
from paddlerec.core.metric import Metric
from paddle.fluid.layers import nn, accuracy
from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper
class AUC(Metric):
"""
Metric For Fluid Model
"""
def __init__(self, **kwargs):
""" """
predict = kwargs.get("input")
label = kwargs.get("label")
curve = kwargs.get("curve", 'ROC')
num_thresholds = kwargs.get("num_thresholds", 2**12 - 1)
topk = kwargs.get("topk", 1)
slide_steps = kwargs.get("slide_steps", 1)
auc_out, batch_auc_out, [
batch_stat_pos, batch_stat_neg, stat_pos, stat_neg
] = fluid.layers.auc(predict,
label,
curve=curve,
num_thresholds=num_thresholds,
topk=topk,
slide_steps=slide_steps)
self._need_clear_list = [(stat_pos.name, "float32"),
(stat_neg.name, "float32")]
self.metrics = dict()
self.metrics["AUC"] = auc_out
self.metrics["BATCH_AUC"] = batch_auc_out
def get_result(self):
return self.metrics
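
# Note: fluid.layers.auc returns both a globally accumulated AUC ("AUC") and a
# sliding-window AUC ("BATCH_AUC"); the positive/negative bucket statistics are
# registered in _need_clear_list so the runner can reset them between passes.
#
# Minimal usage sketch (variable names are illustrative, mirroring the TestAUC
# unit test below):
#
#     predict = fluid.data(name="predict", shape=[-1, 2], dtype='float32')
#     label = fluid.data(name="label", shape=[-1, 1], dtype='int64')
#     auc_metric = AUC(input=predict, label=label, num_thresholds=2**12)
#     fetch_dict = auc_metric.get_result()  # {"AUC": ..., "BATCH_AUC": ...}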
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle.fluid as fluid
from paddlerec.core.metric import Metric
from paddle.fluid.layers import nn, accuracy
from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper
class PrecisionRecall(Metric):
"""
Metric For Fluid Model
"""
def __init__(self, **kwargs):
""" """
helper = LayerHelper("PaddleRec_PrecisionRecall", **kwargs)
predict = kwargs.get("input")
origin_label = kwargs.get("label")
label = fluid.layers.cast(origin_label, dtype="int32")
label.stop_gradient = True
num_cls = kwargs.get("class_num")
max_probs, indices = fluid.layers.nn.topk(predict, k=1)
indices = fluid.layers.cast(indices, dtype="int32")
indices.stop_gradient = True
states_info, _ = helper.create_or_get_global_variable(
name="states_info",
persistable=True,
dtype='float32',
shape=[num_cls, 4])
states_info.stop_gradient = True
helper.set_variable_initializer(
states_info, Constant(
value=0.0, force_cpu=True))
batch_metrics, _ = helper.create_or_get_global_variable(
name="batch_metrics",
persistable=False,
dtype='float32',
shape=[6])
accum_metrics, _ = helper.create_or_get_global_variable(
name="global_metrics",
persistable=False,
dtype='float32',
shape=[6])
batch_states = fluid.layers.fill_constant(
shape=[num_cls, 4], value=0.0, dtype="float32")
batch_states.stop_gradient = True
helper.append_op(
type="precision_recall",
attrs={'class_number': num_cls},
inputs={
'MaxProbs': [max_probs],
'Indices': [indices],
'Labels': [label],
'StatesInfo': [states_info]
},
outputs={
'BatchMetrics': [batch_metrics],
'AccumMetrics': [accum_metrics],
'AccumStatesInfo': [batch_states]
})
helper.append_op(
type="assign",
inputs={'X': [batch_states]},
outputs={'Out': [states_info]})
batch_states.stop_gradient = True
states_info.stop_gradient = True
self._need_clear_list = [("states_info", "float32")]
self.metrics = dict()
self.metrics["precision_recall_f1"] = accum_metrics
self.metrics["accum_states"] = states_info
# self.metrics["batch_metrics"] = batch_metrics
def get_result(self):
return self.metrics
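
# Note: states_info accumulates, for every class, the four counts
# [TP, FP, TN, FN]; accum_metrics ("precision_recall_f1") holds six values in
# the order [macro-precision, macro-recall, macro-F1, micro-precision,
# micro-recall, micro-F1], where, for example,
#
#     macro-precision = mean over classes c of TP_c / (TP_c + FP_c)
#     micro-precision = sum(TP_c) / (sum(TP_c) + sum(FP_c))
#
# (see get_states()/compute_metrics() in the TestPrecisionRecall unit test
# below for the reference computation).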
@@ -23,87 +23,53 @@ from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper


class PosNegRatio(Metric):
    """
    Metric For Fluid Model
    """

    def __init__(self, **kwargs):
        """ """
        helper = LayerHelper("PaddleRec_PosNegRatio", **kwargs)
        pos_score = kwargs.get('pos_score')
        neg_score = kwargs.get('neg_score')

        # Within the batch, count how many positive scores fail to exceed the
        # paired negative score (wrong) and how many exceed it (right).
        wrong = fluid.layers.cast(
            fluid.layers.less_equal(pos_score, neg_score), dtype='float32')
        wrong_cnt = fluid.layers.reduce_sum(wrong)
        right = fluid.layers.cast(
            fluid.layers.less_than(neg_score, pos_score), dtype='float32')
        right_cnt = fluid.layers.reduce_sum(right)

        global_right_cnt, _ = helper.create_or_get_global_variable(
            name="right_cnt", persistable=True, dtype='float32', shape=[1])
        global_wrong_cnt, _ = helper.create_or_get_global_variable(
            name="wrong_cnt", persistable=True, dtype='float32', shape=[1])

        for var in [global_right_cnt, global_wrong_cnt]:
            helper.set_variable_initializer(
                var, Constant(
                    value=0.0, force_cpu=True))

        # Accumulate the per-batch counts into the persistable globals.
        helper.append_op(
            type="elementwise_add",
            inputs={"X": [global_right_cnt],
                    "Y": [right_cnt]},
            outputs={"Out": [global_right_cnt]})
        helper.append_op(
            type="elementwise_add",
            inputs={"X": [global_wrong_cnt],
                    "Y": [wrong_cnt]},
            outputs={"Out": [global_wrong_cnt]})

        # Smoothed positive-negative ratio.
        self.pn = (global_right_cnt + 1.0) / (global_wrong_cnt + 1.0)

        self._need_clear_list = [("right_cnt", "float32"),
                                 ("wrong_cnt", "float32")]

        self.metrics = dict()
        self.metrics['wrong_cnt'] = global_wrong_cnt
        self.metrics['right_cnt'] = global_right_cnt
        self.metrics['pos_neg_ratio'] = self.pn

def get_result(self):
return self.metrics
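
# Note: the reported ratio uses add-one smoothing,
#
#     pos_neg_ratio = (right_cnt + 1) / (wrong_cnt + 1),
#
# where "right" means the positive score strictly exceeds its paired negative
# score and ties count as "wrong" (less_equal). The TestPosNegRatio unit test
# below checks exactly this expression.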
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle.fluid as fluid
from paddlerec.core.metric import Metric
from paddle.fluid.layers import nn, accuracy
from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper


class RecallK(Metric):
    """
    Metric For Fluid Model
    """

    def __init__(self, **kwargs):
        """ """
        helper = LayerHelper("PaddleRec_RecallK", **kwargs)
        predict = kwargs.get("input")
        label = kwargs.get("label")
        k = kwargs.get("k")

        # Top-k accuracy of the current batch: the fraction of instances whose
        # label is among the k highest-scoring predictions.
        batch_accuracy = accuracy(predict, label, k)

        global_ins_cnt, _ = helper.create_or_get_global_variable(
            name="ins_cnt", persistable=True, dtype='float32', shape=[1])
        global_pos_cnt, _ = helper.create_or_get_global_variable(
            name="pos_cnt", persistable=True, dtype='float32', shape=[1])

        for var in [global_ins_cnt, global_pos_cnt]:
            helper.set_variable_initializer(
                var, Constant(
                    value=0.0, force_cpu=True))

        batch_ins_cnt, _ = helper.create_or_get_global_variable(
            name="batch_ins_cnt",
            persistable=False,
            dtype='float32',
            shape=[1])
        batch_pos_cnt, _ = helper.create_or_get_global_variable(
            name="batch_pos_cnt",
            persistable=False,
            dtype='float32',
            shape=[1])
        tmp_ones = helper.create_global_variable(
            name="batch_size_like_ones",
            persistable=False,
            dtype='float32',
            shape=[-1])

        # batch_ins_cnt = number of instances in the current batch
        helper.append_op(
            type='fill_constant_batch_size_like',
            inputs={"Input": label},
            outputs={'Out': [tmp_ones]},
            attrs={
                'shape': [-1, 1],
                'dtype': tmp_ones.dtype,
                'value': float(1.0),
            })
        helper.append_op(
            type="reduce_sum",
            inputs={"X": [tmp_ones]},
            outputs={"Out": [batch_ins_cnt]})
        # batch_pos_cnt = batch size * top-k accuracy = top-k hits in the batch
        helper.append_op(
            type="elementwise_mul",
            inputs={"X": [batch_ins_cnt],
                    "Y": [batch_accuracy]},
            outputs={"Out": [batch_pos_cnt]})
        # Accumulate the per-batch counts into the persistable globals.
        helper.append_op(
            type="elementwise_add",
            inputs={"X": [global_ins_cnt],
                    "Y": [batch_ins_cnt]},
            outputs={"Out": [global_ins_cnt]})
        helper.append_op(
            type="elementwise_add",
            inputs={"X": [global_pos_cnt],
                    "Y": [batch_pos_cnt]},
            outputs={"Out": [global_pos_cnt]})

        self.acc = global_pos_cnt / global_ins_cnt

        self._need_clear_list = [("ins_cnt", "float32"),
                                 ("pos_cnt", "float32")]

        self.metrics = dict()
        self.metrics["ins_cnt"] = global_ins_cnt
        self.metrics["pos_cnt"] = global_pos_cnt
        self.metrics["Recall@%d_ACC" % k] = self.acc

def get_result(self):
return self.metrics
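
# Note: with the counters above, Recall@k is pos_cnt / ins_cnt, i.e. the
# fraction of all instances seen since the counters were last cleared whose
# label appears among the top-k predictions; the TestRecallK unit test below
# asserts this ratio as well as the raw ins_cnt/pos_cnt values.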
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddlerec.core.metrics import AUC
import paddle
import paddle.fluid as fluid
class TestAUC(unittest.TestCase):
def setUp(self):
self.ins_num = 64
self.batch_nums = 3
self.probs = np.random.uniform(0, 1.0,
(self.ins_num, 2)).astype('float32')
self.labels = np.random.choice(range(2), self.ins_num).reshape(
(self.ins_num, 1)).astype('int64')
self.place = fluid.core.CPUPlace()
self.num_thresholds = 2**12
python_auc = fluid.metrics.Auc(name="auc",
curve='ROC',
num_thresholds=self.num_thresholds)
for i in range(self.batch_nums):
python_auc.update(self.probs, self.labels)
self.auc = np.array(python_auc.eval())
def build_network(self):
predict = fluid.data(
name="predict", shape=[-1, 2], dtype='float32', lod_level=0)
label = fluid.data(
name="label", shape=[-1, 1], dtype='int64', lod_level=0)
auc = AUC(input=predict,
label=label,
num_thresholds=self.num_thresholds,
curve='ROC')
return auc
def test_forward(self):
        auc_metric = self.build_network()
        metrics = auc_metric.get_result()
fetch_vars = []
metric_keys = []
for item in metrics.items():
fetch_vars.append(item[1])
metric_keys.append(item[0])
exe = fluid.Executor(self.place)
exe.run(fluid.default_startup_program())
for i in range(self.batch_nums):
outs = exe.run(fluid.default_main_program(),
feed={'predict': self.probs,
'label': self.labels},
fetch_list=fetch_vars,
return_numpy=True)
outs = dict(zip(metric_keys, outs))
self.assertTrue(np.allclose(outs['AUC'], self.auc))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddlerec.core.metrics import PosNegRatio
import paddle
import paddle.fluid as fluid
class TestPosNegRatio(unittest.TestCase):
def setUp(self):
self.ins_num = 64
self.batch_nums = 3
self.probs = []
self.right_cnt = 0.0
self.wrong_cnt = 0.0
for i in range(self.batch_nums):
neg_score = np.random.uniform(0, 1.0,
(self.ins_num, 1)).astype('float32')
pos_score = np.random.uniform(0, 1.0,
(self.ins_num, 1)).astype('float32')
right_cnt = np.sum(np.less(neg_score, pos_score)).astype('int32')
wrong_cnt = np.sum(np.less_equal(pos_score, neg_score)).astype(
'int32')
self.right_cnt += float(right_cnt)
self.wrong_cnt += float(wrong_cnt)
self.probs.append((pos_score, neg_score))
self.place = fluid.core.CPUPlace()
def build_network(self):
pos_score = fluid.data(
name="pos_score", shape=[-1, 1], dtype='float32', lod_level=0)
neg_score = fluid.data(
name="neg_score", shape=[-1, 1], dtype='float32', lod_level=0)
pairwise_pn = PosNegRatio(pos_score=pos_score, neg_score=neg_score)
return pairwise_pn
def test_forward(self):
pn = self.build_network()
metrics = pn.get_result()
fetch_vars = []
metric_keys = []
for item in metrics.items():
fetch_vars.append(item[1])
metric_keys.append(item[0])
exe = fluid.Executor(self.place)
exe.run(fluid.default_startup_program())
for i in range(self.batch_nums):
outs = exe.run(fluid.default_main_program(),
feed={
'pos_score': self.probs[i][0],
'neg_score': self.probs[i][1]
},
fetch_list=fetch_vars,
return_numpy=True)
outs = dict(zip(metric_keys, outs))
self.assertTrue(np.allclose(outs['right_cnt'], self.right_cnt))
self.assertTrue(np.allclose(outs['wrong_cnt'], self.wrong_cnt))
self.assertTrue(
np.allclose(outs['pos_neg_ratio'],
np.array((self.right_cnt + 1.0) / (self.wrong_cnt + 1.0
))))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddlerec.core.metrics import PrecisionRecall
import paddle
import paddle.fluid as fluid
def calc_precision(tp_count, fp_count):
if tp_count > 0.0 or fp_count > 0.0:
return tp_count / (tp_count + fp_count)
return 1.0
def calc_recall(tp_count, fn_count):
if tp_count > 0.0 or fn_count > 0.0:
return tp_count / (tp_count + fn_count)
return 1.0
def calc_f1_score(precision, recall):
if precision > 0.0 or recall > 0.0:
return 2 * precision * recall / (precision + recall)
return 0.0
def get_states(idxs, labels, cls_num, weights=None, batch_nums=1):
ins_num = idxs.shape[0]
# TP FP TN FN
states = np.zeros((cls_num, 4)).astype('float32')
for i in range(ins_num):
w = weights[i] if weights is not None else 1.0
idx = idxs[i][0]
label = labels[i][0]
if idx == label:
states[idx][0] += w
for j in range(cls_num):
states[j][2] += w
states[idx][2] -= w
else:
states[label][3] += w
states[idx][1] += w
for j in range(cls_num):
states[j][2] += w
states[label][2] -= w
states[idx][2] -= w
return states
def compute_metrics(states, cls_num):
total_tp_count = 0.0
total_fp_count = 0.0
total_fn_count = 0.0
macro_avg_precision = 0.0
macro_avg_recall = 0.0
for i in range(cls_num):
total_tp_count += states[i][0]
total_fp_count += states[i][1]
total_fn_count += states[i][3]
macro_avg_precision += calc_precision(states[i][0], states[i][1])
macro_avg_recall += calc_recall(states[i][0], states[i][3])
metrics = []
macro_avg_precision /= cls_num
macro_avg_recall /= cls_num
metrics.append(macro_avg_precision)
metrics.append(macro_avg_recall)
metrics.append(calc_f1_score(macro_avg_precision, macro_avg_recall))
micro_avg_precision = calc_precision(total_tp_count, total_fp_count)
metrics.append(micro_avg_precision)
micro_avg_recall = calc_recall(total_tp_count, total_fn_count)
metrics.append(micro_avg_recall)
metrics.append(calc_f1_score(micro_avg_precision, micro_avg_recall))
return np.array(metrics).astype('float32')
class TestPrecisionRecall(unittest.TestCase):
def setUp(self):
self.ins_num = 64
self.cls_num = 10
self.batch_nums = 3
self.datas = []
self.states = np.zeros((self.cls_num, 4)).astype('float32')
for i in range(self.batch_nums):
probs = np.random.uniform(0, 1.0, (self.ins_num,
self.cls_num)).astype('float32')
idxs = np.array(np.argmax(
probs, axis=1)).reshape(self.ins_num, 1).astype('int32')
labels = np.random.choice(range(self.cls_num),
self.ins_num).reshape(
(self.ins_num, 1)).astype('int32')
self.datas.append((probs, labels))
states = get_states(idxs, labels, self.cls_num)
self.states = np.add(self.states, states)
self.metrics = compute_metrics(self.states, self.cls_num)
self.place = fluid.core.CPUPlace()
def build_network(self):
predict = fluid.data(
name="predict",
shape=[-1, self.cls_num],
dtype='float32',
lod_level=0)
label = fluid.data(
name="label", shape=[-1, 1], dtype='int32', lod_level=0)
precision_recall = PrecisionRecall(
input=predict, label=label, class_num=self.cls_num)
return precision_recall
def test_forward(self):
precision_recall = self.build_network()
metrics = precision_recall.get_result()
fetch_vars = []
metric_keys = []
for item in metrics.items():
fetch_vars.append(item[1])
metric_keys.append(item[0])
exe = fluid.Executor(self.place)
exe.run(fluid.default_startup_program())
for i in range(self.batch_nums):
outs = exe.run(
fluid.default_main_program(),
feed={'predict': self.datas[i][0],
'label': self.datas[i][1]},
fetch_list=fetch_vars,
return_numpy=True)
outs = dict(zip(metric_keys, outs))
self.assertTrue(np.allclose(outs['accum_states'], self.states))
self.assertTrue(np.allclose(outs['precision_recall_f1'], self.metrics))
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddlerec.core.metrics import RecallK
import paddle
import paddle.fluid as fluid
class TestRecallK(unittest.TestCase):
def setUp(self):
self.ins_num = 64
self.cls_num = 10
self.topk = 2
self.batch_nums = 3
self.datas = []
self.match_num = 0.0
for i in range(self.batch_nums):
z = np.random.uniform(0, 1.0, (self.ins_num,
self.cls_num)).astype('float32')
            pred = np.exp(z) / np.sum(np.exp(z), axis=1, keepdims=True)  # row-wise softmax
label = np.random.choice(range(self.cls_num),
self.ins_num).reshape(
(self.ins_num, 1)).astype('int64')
self.datas.append((pred, label))
max_k_preds = pred.argsort(
axis=1)[:, -self.topk:][:, ::-1] #top-k label
match_array = np.logical_or.reduce(max_k_preds == label, axis=1)
self.match_num += np.sum(match_array).astype('float32')
self.place = fluid.core.CPUPlace()
def build_network(self):
pred = fluid.data(
name="pred",
shape=[-1, self.cls_num],
dtype='float32',
lod_level=0)
label = fluid.data(
name="label", shape=[-1, 1], dtype='int64', lod_level=0)
recall_k = RecallK(input=pred, label=label, k=self.topk)
return recall_k
def test_forward(self):
net = self.build_network()
metrics = net.get_result()
fetch_vars = []
metric_keys = []
for item in metrics.items():
fetch_vars.append(item[1])
metric_keys.append(item[0])
exe = fluid.Executor(self.place)
exe.run(fluid.default_startup_program())
for i in range(self.batch_nums):
outs = exe.run(
fluid.default_main_program(),
feed={'pred': self.datas[i][0],
'label': self.datas[i][1]},
fetch_list=fetch_vars,
return_numpy=True)
outs = dict(zip(metric_keys, outs))
self.assertTrue(
np.allclose(outs['ins_cnt'], self.ins_num * self.batch_nums))
self.assertTrue(np.allclose(outs['pos_cnt'], self.match_num))
self.assertTrue(
np.allclose(outs['Recall@%d_ACC' % (self.topk)],
np.array(self.match_num / (self.ins_num *
self.batch_nums))))
if __name__ == '__main__':
unittest.main()