From 70dc5f49b705337798e272a5dd22aed665165c3a Mon Sep 17 00:00:00 2001 From: pangyoki Date: Fri, 21 May 2021 10:35:11 +0800 Subject: [PATCH] [NPU] cast indices and label if their type is not consistent in accuracy npu op (#33016) * cast indices and label if their type is not consistent * fix bug * add unittest --- .../operators/metrics/accuracy_op_npu.cc | 32 ++++++++++++- .../unittests/npu/test_accuracy_op_npu.py | 48 +++++++++++++++++++ 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/metrics/accuracy_op_npu.cc b/paddle/fluid/operators/metrics/accuracy_op_npu.cc index 9c5e157a97..c18b8590db 100644 --- a/paddle/fluid/operators/metrics/accuracy_op_npu.cc +++ b/paddle/fluid/operators/metrics/accuracy_op_npu.cc @@ -39,12 +39,42 @@ class AccuracyNPUKernel : public framework::OpKernel { return; } + // cast `indices` or `label` if their type is not consistent + Tensor cast_indices(framework::proto::VarType::INT32); + Tensor cast_label(framework::proto::VarType::INT32); + if (indices->type() != label->type()) { + auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32); + if (indices->type() != framework::proto::VarType::INT32) { + cast_indices.Resize(indices->dims()); + cast_indices.mutable_data(ctx.GetPlace()); + auto runner_cast_indices = + NpuOpRunner("Cast", {*indices}, {cast_indices}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_indices.Run(stream); + } else { + cast_indices.ShareDataWith(*indices); + } + if (label->type() != framework::proto::VarType::INT32) { + cast_label.Resize(label->dims()); + cast_label.mutable_data(ctx.GetPlace()); + auto runner_cast_label = + NpuOpRunner("Cast", {*label}, {cast_label}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_label.Run(stream); + } else { + cast_label.ShareDataWith(*label); + } + } else { + cast_indices.ShareDataWith(*indices); + cast_label.ShareDataWith(*label); + } + // equal Tensor tmp_equal(framework::proto::VarType::BOOL); tmp_equal.Resize(inference->dims()); tmp_equal.mutable_data(ctx.GetPlace()); auto runner_equal = - NpuOpRunner("Equal", {*indices, *label}, {tmp_equal}, {}); + NpuOpRunner("Equal", {cast_indices, cast_label}, {tmp_equal}, {}); runner_equal.Run(stream); // cast equal diff --git a/python/paddle/fluid/tests/unittests/npu/test_accuracy_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_accuracy_op_npu.py index aa22863983..5aeca5abd9 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_accuracy_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_accuracy_op_npu.py @@ -87,5 +87,53 @@ class TestAccuracy2(TestAccuracy): } +class TestAccuracyType(TestAccuracy): + def setUp(self): + self.op_type = "accuracy" + self.set_npu() + self.init_dtype() + np.random.seed(SEED) + n = 8192 + infer = np.random.random((n, 100)).astype(self.dtype) + indices = np.random.randint(0, 1000, (n, 100)).astype('int64') + label = np.random.randint(0, 1000, (n, 1)).astype('int32') + self.inputs = {'Out': infer, 'Indices': indices, "Label": label} + num_correct = 0 + for rowid in range(n): + for ele in indices[rowid]: + if ele == label[rowid]: + num_correct += 1 + break + self.outputs = { + 'Accuracy': np.array([num_correct / float(n)]).astype(self.dtype), + 'Correct': np.array([num_correct]).astype("int32"), + 'Total': np.array([n]).astype("int32") + } + + +class TestAccuracyType2(TestAccuracy): + def setUp(self): + self.op_type = "accuracy" + self.set_npu() + self.init_dtype() + np.random.seed(SEED) + n = 8192 + infer = np.random.random((n, 100)).astype(self.dtype) + indices = np.random.randint(0, 1000, (n, 100)).astype('int32') + label = np.random.randint(0, 1000, (n, 1)).astype('int64') + self.inputs = {'Out': infer, 'Indices': indices, "Label": label} + num_correct = 0 + for rowid in range(n): + for ele in indices[rowid]: + if ele == label[rowid]: + num_correct += 1 + break + self.outputs = { + 'Accuracy': np.array([num_correct / float(n)]).astype(self.dtype), + 'Correct': np.array([num_correct]).astype("int32"), + 'Total': np.array([n]).astype("int32") + } + + if __name__ == '__main__': unittest.main() -- GitLab