From 383de29565228ab0db2bb4f3455ce264d80567d1 Mon Sep 17 00:00:00 2001
From: fwenguang <95677191+fwenguang@users.noreply.github.com>
Date: Thu, 10 Feb 2022 20:52:55 +0800
Subject: [PATCH] [MLU] add mlu kernel for accuracy op (#39337)

* [MLU] add mlu kernel for accuracy op

* fix license format

* fix error message
---
 .../operators/metrics/accuracy_op_mlu.cc      | 167 ++++++++++++++++++
 .../unittests/mlu/test_accuracy_op_mlu.py     | 136 ++++++++++++++
 2 files changed, 303 insertions(+)
 create mode 100644 paddle/fluid/operators/metrics/accuracy_op_mlu.cc
 create mode 100755 python/paddle/fluid/tests/unittests/mlu/test_accuracy_op_mlu.py

diff --git a/paddle/fluid/operators/metrics/accuracy_op_mlu.cc b/paddle/fluid/operators/metrics/accuracy_op_mlu.cc
new file mode 100644
index 00000000000..0649f9172ee
--- /dev/null
+++ b/paddle/fluid/operators/metrics/accuracy_op_mlu.cc
@@ -0,0 +1,167 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/metrics/accuracy_op.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class AccuracyMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* indices = ctx.Input<Tensor>("Indices");
+    auto* label = ctx.Input<Tensor>("Label");
+
+    auto* accuracy = ctx.Output<Tensor>("Accuracy");
+    auto* correct = ctx.Output<Tensor>("Correct");
+    auto* total = ctx.Output<Tensor>("Total");
+
+    int num_samples = indices->dims()[0];
+    if (num_samples == 0) {
+      return;
+    }
+
+    // cast `indices` or `label` if their type is not INT32
+    Tensor indices_int32(VT::INT32);
+    Tensor label_int32(VT::INT32);
+    if (indices->type() != VT::INT32) {
+      PADDLE_ENFORCE_EQ(MLUSupportsCast(indices->type(), VT::INT32), true,
+                        platform::errors::Unavailable(
+                            "In accuracy mlu kernel, cast indices from [%s] to "
+                            "[%s] is not supported.",
+                            framework::DataTypeToString(indices->type()),
+                            framework::DataTypeToString(VT::INT32)));
+      indices_int32.Resize(indices->dims());
+      indices_int32.mutable_data<int>(ctx.GetPlace());
+      MLUCnnlTensorDesc org_indices_desc(*indices);
+      MLUCnnlTensorDesc indices_int32_desc(indices_int32);
+      cnnlCastDataType_t cast_type =
+          GetCastDataType(indices->type(), VT::INT32);
+      MLUCnnl::Cast(ctx, cast_type, org_indices_desc.get(), GetBasePtr(indices),
+                    indices_int32_desc.get(), GetBasePtr(&indices_int32));
+    } else {
+      indices_int32.ShareDataWith(*indices);
+    }
+    if (label->type() != VT::INT32) {
+      PADDLE_ENFORCE_EQ(
+          MLUSupportsCast(label->type(), VT::INT32), true,
+          platform::errors::Unavailable(
+              "In accuracy mlu kernel, cast label from [%s] to [%s] "
+              "is not supported.",
+              framework::DataTypeToString(label->type()),
+              framework::DataTypeToString(VT::INT32)));
+      label_int32.Resize(label->dims());
+      label_int32.mutable_data<int>(ctx.GetPlace());
+      MLUCnnlTensorDesc org_label_desc(*label);
+      MLUCnnlTensorDesc label_int32_desc(label_int32);
+      cnnlCastDataType_t cast_type = GetCastDataType(label->type(), VT::INT32);
+      MLUCnnl::Cast(ctx, cast_type, org_label_desc.get(), GetBasePtr(label),
+                    label_int32_desc.get(), GetBasePtr(&label_int32));
+    } else {
+      label_int32.ShareDataWith(*label);
+    }
+
+    // equal
+    MLUCnnlTensorDesc indices_int32_desc(indices_int32);
+    MLUCnnlTensorDesc label_int32_desc(label_int32);
+    Tensor equal_tensor(VT::BOOL);
+    equal_tensor.Resize(indices->dims());
+    equal_tensor.mutable_data<bool>(ctx.GetPlace());
+    MLUCnnlTensorDesc equal_tensor_desc(equal_tensor);
+    MLUCnnl::Logic(ctx, CNNL_LOGIC_OP_EQ, indices_int32_desc.get(),
+                   GetBasePtr(&indices_int32), label_int32_desc.get(),
+                   GetBasePtr(&label_int32), equal_tensor_desc.get(),
+                   GetBasePtr(&equal_tensor));
+
+    // cast equal
+    Tensor equal_fp32(VT::FP32);
+    equal_fp32.Resize(indices->dims());
+    equal_fp32.mutable_data<float>(ctx.GetPlace());
+    MLUCnnlTensorDesc equal_fp32_desc(equal_fp32);
+    cnnlCastDataType_t equal_cast_type = GetCastDataType(VT::BOOL, VT::FP32);
+    MLUCnnl::Cast(ctx, equal_cast_type, equal_tensor_desc.get(),
+                  GetBasePtr(&equal_tensor), equal_fp32_desc.get(),
+                  GetBasePtr(&equal_fp32));
+
+    // [correct]
+    // reduce_max
+    Tensor correct_max(VT::FP32);
+    correct_max.Resize(framework::make_ddim({num_samples}));
+    correct_max.mutable_data<float>(ctx.GetPlace());
+    MLUCnnlTensorDesc correct_max_desc(correct_max);
+    MLUCnnlReduceDesc reduce_max_desc(
+        {1}, CNNL_REDUCE_MAX, ToCnnlDataType<float>(), CNNL_NOT_PROPAGATE_NAN,
+        CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);
+    MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduce_max_desc.get(),
+                    nullptr, equal_fp32_desc.get(), GetBasePtr(&equal_fp32),
+                    0 /*indices_size*/, nullptr, nullptr,
+                    correct_max_desc.get(), GetBasePtr(&correct_max));
+
+    // reduce_sum
+    Tensor correct_sum(VT::FP32);
+    correct_sum.Resize(correct->dims());
+    correct_sum.mutable_data<float>(ctx.GetPlace());
+    MLUCnnlTensorDesc correct_sum_desc(correct_sum);
+    MLUCnnlReduceDesc reduce_sum_desc(
+        {0}, CNNL_REDUCE_ADD, ToCnnlDataType<float>(), CNNL_NOT_PROPAGATE_NAN,
+        CNNL_REDUCE_NO_INDICES, CNNL_32BIT_INDICES);
+    MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduce_sum_desc.get(),
+                    nullptr, correct_max_desc.get(), GetBasePtr(&correct_max),
+                    0 /*indices_size*/, nullptr, nullptr,
+                    correct_sum_desc.get(), GetBasePtr(&correct_sum));
+
+    // cast to int
+    correct->mutable_data<int>(ctx.GetPlace());
+    MLUCnnlTensorDesc correct_desc(*correct);
+    cnnlCastDataType_t correct_cast_type = GetCastDataType(VT::FP32, VT::INT32);
+    MLUCnnl::Cast(ctx, correct_cast_type, correct_sum_desc.get(),
+                  GetBasePtr(&correct_sum), correct_desc.get(),
+                  GetBasePtr(correct));
+
+    // [total]
+    total->mutable_data<int>(ctx.GetPlace());
+    MLUCnnlTensorDesc total_desc(*total);
+    MLUCnnl::Fill(ctx, num_samples, total_desc.get(), GetBasePtr(total));
+
+    // use `total` of type `float32` for calculating accuracy
+    Tensor total_fp32(VT::FP32);
+    total_fp32.Resize(total->dims());
+    total_fp32.mutable_data<float>(ctx.GetPlace());
+    MLUCnnlTensorDesc total_fp32_desc(total_fp32);
+    MLUCnnl::Fill(ctx, static_cast<float>(num_samples), total_fp32_desc.get(),
+                  GetBasePtr(&total_fp32));
+
+    // [accuracy]
+    accuracy->mutable_data<float>(ctx.GetPlace());
+    MLUCnnlTensorDesc accuracy_desc(*accuracy);
+    MLUCnnl::Div(ctx, CNNL_COMPUTATION_HIGH_PRECISION, correct_sum_desc.get(),
+                 GetBasePtr(&correct_sum), total_fp32_desc.get(),
+                 GetBasePtr(&total_fp32), accuracy_desc.get(),
+                 GetBasePtr(accuracy));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_MLU_KERNEL(accuracy, ops::AccuracyMLUKernel<float>,
+                       ops::AccuracyMLUKernel<paddle::platform::float16>,
+                       ops::AccuracyMLUKernel<int16_t>,
+                       ops::AccuracyMLUKernel<int64_t>,
+                       ops::AccuracyMLUKernel<uint8_t>,
+                       ops::AccuracyMLUKernel<int>);
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_accuracy_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_accuracy_op_mlu.py
new file mode 100755
index 00000000000..e229966c12d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_accuracy_op_mlu.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+sys.path.append('..')
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import compiler, Program, program_guard
+
+
+class TestAccuracyOp(OpTest):
+    def setUp(self):
+        self.op_type = "accuracy"
+        self.place = paddle.device.MLUPlace(0)
+        self.__class__.use_mlu = True
+        self.dtype = np.float32
+        self.init_dtype()
+        n = 8192
+        infer = np.random.random((n, 1)).astype(self.dtype)
+        indices = np.random.randint(0, 2, (n, 1)).astype('int32')
+        label = np.random.randint(0, 2, (n, 1)).astype('int32')
+        self.inputs = {'Out': infer, 'Indices': indices, "Label": label}
+        num_correct = 0
+        for rowid in range(n):
+            for ele in indices[rowid]:
+                if ele == label[rowid]:
+                    num_correct += 1
+                    break
+        self.outputs = {
+            'Accuracy': np.array([num_correct / float(n)]).astype(self.dtype),
+            'Correct': np.array([num_correct]).astype("int32"),
+            'Total': np.array([n]).astype("int32")
+        }
+
+    def init_dtype(self):
+        pass
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+
+class TestAccuracyOpFp16(TestAccuracyOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-3)
+
+
+class TestAccuracyOpError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            # The input type of accuracy_op must be Variable.
+            x1 = fluid.create_lod_tensor(
+                np.array([[-1]]), [[1]], fluid.MLUPlace(0))
+            label = fluid.layers.data(
+                name='label', shape=[-1, 1], dtype="int32")
+            self.assertRaises(TypeError, fluid.layers.accuracy, x1, label)
+            self.assertRaises(TypeError, paddle.metric.accuracy, x1, label)
+            # The input dtype of accuracy_op must be float32 or float64.
+            x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
+            self.assertRaises(TypeError, fluid.layers.accuracy, x2, label)
+            self.assertRaises(TypeError, paddle.metric.accuracy, x2, label)
+            x3 = fluid.layers.data(name='input', shape=[-1, 2], dtype="float16")
+            fluid.layers.accuracy(input=x3, label=label)
+            paddle.metric.accuracy(input=x3, label=label)
+
+
+class TestAccuracyAPI1(unittest.TestCase):
+    def setUp(self):
+        self.predictions = paddle.static.data(
+            shape=[2, 5], name="predictions", dtype="float32")
+        self.label = paddle.static.data(
+            shape=[2, 1], name="labels", dtype="int32")
+        self.result = paddle.static.accuracy(
+            input=self.predictions, label=self.label, k=1)
+        self.input_predictions = np.array(
+            [[0.2, 0.1, 0.4, 0.1, 0.1], [0.2, 0.3, 0.1, 0.15, 0.25]],
+            dtype="float32")
+        self.input_labels = np.array([[2], [0]], dtype="int32")
+        self.expect_value = np.array([0.5], dtype='float32')
+
+    def test_api(self):
+        exe = paddle.static.Executor()
+        result, = exe.run(feed={
+            "predictions": self.input_predictions,
+            'labels': self.input_labels
+        },
+                          fetch_list=[self.result.name])
+        self.assertEqual((result == self.expect_value).all(), True)
+
+
+class TestAccuracyAPI2(unittest.TestCase):
+    def test_api(self):
+        with fluid.dygraph.guard():
+            predictions = paddle.to_tensor(
+                [[0.2, 0.1, 0.4, 0.1, 0.1], [0.2, 0.3, 0.1, 0.15, 0.25]],
+                dtype='float32')
+            label = paddle.to_tensor([[2], [0]], dtype="int32")
+            result = paddle.static.accuracy(input=predictions, label=label, k=1)
+            expect_value = np.array([0.5], dtype='float32')
+            self.assertEqual((result.numpy() == expect_value).all(), True)
+
+
+class TestAccuracyAPI(unittest.TestCase):
+    def test_api(self):
+        with fluid.dygraph.guard():
+            predictions = paddle.to_tensor(
+                [[0.2, 0.1, 0.4, 0.1, 0.1], [0.2, 0.3, 0.1, 0.15, 0.25]],
+                dtype='float32')
+            label = paddle.to_tensor([[2], [0]], dtype="int32")
+            result = paddle.metric.accuracy(input=predictions, label=label, k=1)
+            expect_value = np.array([0.5], dtype='float32')
+
+            self.assertEqual((result.numpy() == expect_value).all(), True)
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    unittest.main()
-- 
GitLab
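
For context, a minimal end-to-end sketch of driving the new kernel from Python. This is a hedged illustration, not part of the patch: it assumes a Paddle build compiled with MLU support and one visible MLU card, and the 'mlu:0' device string follows the convention other MLU code in this tree uses rather than anything this patch defines. The inputs and the expected 0.5 accuracy mirror TestAccuracyAPI above.

import numpy as np
import paddle

# Assumption: an MLU-enabled Paddle build; selects the first MLU card.
paddle.set_device('mlu:0')

# Two samples, five classes: sample 0's top-1 index (2) matches its label,
# sample 1's top-1 index (1) does not match its label (0).
predictions = paddle.to_tensor(
    [[0.2, 0.1, 0.4, 0.1, 0.1], [0.2, 0.3, 0.1, 0.15, 0.25]],
    dtype='float32')
label = paddle.to_tensor([[2], [0]], dtype='int32')

# paddle.metric.accuracy extracts the top-k indices and dispatches the
# accuracy op, which this patch routes to the MLU kernel: compare top-k
# indices against the label, reduce_max over k, reduce_sum over samples,
# then divide by the sample count.
acc = paddle.metric.accuracy(input=predictions, label=label, k=1)
print(acc.numpy())  # expected: [0.5]

The sketch exercises the same path as the dygraph test cases: one correct top-1 prediction out of two samples yields the 0.5 value the unit tests assert.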