未验证 提交 70dc5f49 编写于 作者: P pangyoki 提交者: GitHub

[NPU] cast indices and label if their type is not consistent in accuracy npu op (#33016)

* cast indices and label if their type is not consistent

* fix bug

* add unittest
上级 a96e8bc9
......@@ -39,12 +39,42 @@ class AccuracyNPUKernel : public framework::OpKernel<T> {
return;
}
// cast `indices` or `label` if their type is not consistent
Tensor cast_indices(framework::proto::VarType::INT32);
Tensor cast_label(framework::proto::VarType::INT32);
if (indices->type() != label->type()) {
auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32);
if (indices->type() != framework::proto::VarType::INT32) {
cast_indices.Resize(indices->dims());
cast_indices.mutable_data<int>(ctx.GetPlace());
auto runner_cast_indices =
NpuOpRunner("Cast", {*indices}, {cast_indices},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_indices.Run(stream);
} else {
cast_indices.ShareDataWith(*indices);
}
if (label->type() != framework::proto::VarType::INT32) {
cast_label.Resize(label->dims());
cast_label.mutable_data<int>(ctx.GetPlace());
auto runner_cast_label =
NpuOpRunner("Cast", {*label}, {cast_label},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_label.Run(stream);
} else {
cast_label.ShareDataWith(*label);
}
} else {
cast_indices.ShareDataWith(*indices);
cast_label.ShareDataWith(*label);
}
// equal
Tensor tmp_equal(framework::proto::VarType::BOOL);
tmp_equal.Resize(inference->dims());
tmp_equal.mutable_data<bool>(ctx.GetPlace());
auto runner_equal =
NpuOpRunner("Equal", {*indices, *label}, {tmp_equal}, {});
NpuOpRunner("Equal", {cast_indices, cast_label}, {tmp_equal}, {});
runner_equal.Run(stream);
// cast equal
......
......@@ -87,5 +87,53 @@ class TestAccuracy2(TestAccuracy):
}
class TestAccuracyType(TestAccuracy):
def setUp(self):
self.op_type = "accuracy"
self.set_npu()
self.init_dtype()
np.random.seed(SEED)
n = 8192
infer = np.random.random((n, 100)).astype(self.dtype)
indices = np.random.randint(0, 1000, (n, 100)).astype('int64')
label = np.random.randint(0, 1000, (n, 1)).astype('int32')
self.inputs = {'Out': infer, 'Indices': indices, "Label": label}
num_correct = 0
for rowid in range(n):
for ele in indices[rowid]:
if ele == label[rowid]:
num_correct += 1
break
self.outputs = {
'Accuracy': np.array([num_correct / float(n)]).astype(self.dtype),
'Correct': np.array([num_correct]).astype("int32"),
'Total': np.array([n]).astype("int32")
}
class TestAccuracyType2(TestAccuracy):
def setUp(self):
self.op_type = "accuracy"
self.set_npu()
self.init_dtype()
np.random.seed(SEED)
n = 8192
infer = np.random.random((n, 100)).astype(self.dtype)
indices = np.random.randint(0, 1000, (n, 100)).astype('int32')
label = np.random.randint(0, 1000, (n, 1)).astype('int64')
self.inputs = {'Out': infer, 'Indices': indices, "Label": label}
num_correct = 0
for rowid in range(n):
for ele in indices[rowid]:
if ele == label[rowid]:
num_correct += 1
break
self.outputs = {
'Accuracy': np.array([num_correct / float(n)]).astype(self.dtype),
'Correct': np.array([num_correct]).astype("int32"),
'Total': np.array([n]).astype("int32")
}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册