From b29129395a48abe67e1730d2ff4cca0c267715b7 Mon Sep 17 00:00:00 2001
From: cifar10 <41565156+cifar10@users.noreply.github.com>
Date: Tue, 21 Jun 2022 15:06:46 +0800
Subject: [PATCH] add mlu arg_max kernel (#43624)

---
 paddle/fluid/operators/arg_max_op_mlu.cc      | 112 +++++
 .../unittests/mlu/test_arg_max_op_mlu.py      | 388 ++++++++++++++++++
 2 files changed, 500 insertions(+)
 create mode 100644 paddle/fluid/operators/arg_max_op_mlu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_arg_max_op_mlu.py

diff --git a/paddle/fluid/operators/arg_max_op_mlu.cc b/paddle/fluid/operators/arg_max_op_mlu.cc
new file mode 100644
index 00000000000..f3fae7591ac
--- /dev/null
+++ b/paddle/fluid/operators/arg_max_op_mlu.cc
@@ -0,0 +1,112 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class ArgMaxMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<framework::Tensor>("X");
+    auto* out = ctx.Output<framework::Tensor>("Out");
+    auto axis = static_cast<int>(ctx.Attr<int64_t>("axis"));
+    auto dtype = ctx.Attr<int>("dtype");
+    const bool& flatten = ctx.Attr<bool>("flatten");
+
+    if (x->numel() == 0) return;
+    PADDLE_ENFORCE_EQ(
+        (dtype == 2 || dtype == 3), true,
+        platform::errors::InvalidArgument(
+            "The attribute of dtype in argmax op must be [%s] or [%s], "
+            "but "
+            "received [%s]",
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT64),
+            paddle::framework::DataTypeToString(
+                framework::proto::VarType::INT32),
+            paddle::framework::DataTypeToString(
+                static_cast<framework::proto::VarType::Type>(dtype))));
+
+    if (axis < 0) {
+      framework::DDim x_dims;
+      x_dims = x->dims();
+      axis += x_dims.size();
+    }
+
+    framework::Tensor flatten_x(x->type());
+    flatten_x.ShareDataWith(*x);
+    if (flatten) {
+      flatten_x.Resize(phi::make_ddim({x->numel()}));
+      // when flatten is true, reduce over axis 0 of the flattened tensor
+      axis = 0;
+    }
+    std::vector<int> reduce_dims;
+    reduce_dims.push_back(axis);
+
+    auto out_dims = out->dims();
+    int out_count = out_dims[0];
+    for (int i = 1; i < out_dims.size(); i++) {
+      out_count = out_count * out_dims[i];
+    }
+    size_t indices_size_inbytes = out_count * sizeof(int32_t);
+    auto& dev_ctx = ctx.template device_context<MLUDeviceContext>();
+    framework::Tensor value_out =
+        ctx.AllocateTmpTensor<T, MLUDeviceContext>(out->dims(), dev_ctx);
+    MLUCnnlTensorDesc value_out_desc(value_out);
+    MLUCnnlTensorDesc input_desc(flatten_x, CNNL_LAYOUT_ARRAY,
+                                 ToCnnlDataType(flatten_x.dtype()));
+    MLUCnnlReduceDesc reduction_desc(
+        reduce_dims, CNNL_REDUCE_MAX_LAST_INDEX, ToCnnlDataType<T>(),
+        CNNL_NOT_PROPAGATE_NAN, CNNL_REDUCE_ONLY_INDICES, CNNL_32BIT_INDICES);
+
+    if (dtype == 2) {
+      out->template mutable_data<int32_t>(ctx.GetPlace());
+      MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(),
+                      nullptr, input_desc.get(), GetBasePtr(&flatten_x),
+                      indices_size_inbytes /*indices_size*/, GetBasePtr(out),
+                      nullptr, value_out_desc.get(), GetBasePtr(&value_out));
+    } else {
+      out->template mutable_data<int64_t>(ctx.GetPlace());
+      framework::Tensor out_int32 =
+          ctx.AllocateTmpTensor<int32_t, MLUDeviceContext>(out->dims(),
+                                                           dev_ctx);
+      MLUCnnl::Reduce(ctx, true /*need_workspace*/, reduction_desc.get(),
+                      nullptr, input_desc.get(), GetBasePtr(&flatten_x),
+                      indices_size_inbytes /*indices_size*/,
+                      GetBasePtr(&out_int32), nullptr, value_out_desc.get(),
+                      GetBasePtr(&value_out));
+
+      // cast indices type to int64
+      MLUCnnlTensorDesc out_int32_desc(out_int32);
+      MLUCnnlTensorDesc cast_output_desc(*out);
+      cnnlCastDataType_t cast_type = GetCastDataType(VT::INT32, VT::INT64);
+      MLUCnnl::Cast(ctx, cast_type, out_int32_desc.get(),
+                    GetBasePtr(&out_int32), cast_output_desc.get(),
+                    GetBasePtr(out));
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OP_MLU_KERNEL(arg_max, ops::ArgMaxMLUKernel<int>,
+                       ops::ArgMaxMLUKernel<float>,
+                       ops::ArgMaxMLUKernel<plat::float16>);
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_arg_max_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_arg_max_op_mlu.py
new file mode 100644
index 00000000000..bd943e05b2d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_arg_max_op_mlu.py
@@ -0,0 +1,388 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from __future__ import print_function + +import unittest +import numpy as np +import sys + +sys.path.append("..") +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import Program, program_guard + +paddle.enable_static() + + +class BaseTestCase(OpTest): + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.MLUPlace(0) + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4, 5) + self.dtype = 'float32' + self.axis = 0 + + def setUp(self): + self.set_mlu() + self.initTestCase() + self.x = (1000 * np.random.random(self.dims)).astype(self.dtype) + self.inputs = {'X': self.x} + self.attrs = {'axis': self.axis} + self.outputs = {'Out': np.argmax(self.x, axis=self.axis)} + + def test_check_output(self): + self.check_output_with_place(self.place) + + +# test argmax, dtype: float16 +class TestArgMaxFloat16Case1(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4, 5) + self.dtype = 'float16' + self.axis = -1 + + +class TestArgMaxFloat16Case2(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4, 5) + self.dtype = 'float16' + self.axis = 0 + + +class TestArgMaxFloat16Case3(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4, 5) + self.dtype = 'float16' + self.axis = 1 + + +class TestArgMaxFloat16Case4(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4, 5) + self.dtype = 'float16' + self.axis = 2 + + +class TestArgMaxFloat16Case5(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4) + self.dtype = 'float16' + self.axis = -1 + + +class TestArgMaxFloat16Case6(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4) + self.dtype = 'float16' + self.axis = 0 + + +class TestArgMaxFloat16Case7(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4) + self.dtype = 'float16' + self.axis = 1 + + +class TestArgMaxFloat16Case8(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (1, ) + self.dtype = 'float16' + self.axis = 0 + + +class TestArgMaxFloat16Case9(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (2, ) + self.dtype = 'float16' + self.axis = 0 + + +class TestArgMaxFloat16Case10(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, ) + self.dtype = 'float16' + self.axis = 0 + + +# test argmax, dtype: float32 +class TestArgMaxFloat32Case1(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4, 5) + self.dtype = 'float32' + self.axis = -1 + + +class TestArgMaxFloat32Case2(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4, 5) + self.dtype = 'float32' + self.axis = 0 + + +class TestArgMaxFloat32Case3(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4, 5) + self.dtype = 'float32' + self.axis = 1 + + +class TestArgMaxFloat32Case4(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4, 5) + self.dtype = 'float32' + self.axis = 2 + + +class TestArgMaxFloat32Case5(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4) + self.dtype = 'float32' + self.axis = -1 + + +class TestArgMaxFloat32Case6(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + 
self.dims = (3, 4) + self.dtype = 'float32' + self.axis = 0 + + +class TestArgMaxFloat32Case7(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4) + self.dtype = 'float32' + self.axis = 1 + + +class TestArgMaxFloat32Case8(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (1, ) + self.dtype = 'float32' + self.axis = 0 + + +class TestArgMaxFloat32Case9(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (2, ) + self.dtype = 'float32' + self.axis = 0 + + +class TestArgMaxFloat32Case10(BaseTestCase): + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, ) + self.dtype = 'float32' + self.axis = 0 + + +class BaseTestComplex1_1(OpTest): + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.MLUPlace(0) + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (4, 5, 6) + self.dtype = 'float32' + self.axis = 2 + + def setUp(self): + self.set_mlu() + self.initTestCase() + self.x = (np.random.random(self.dims)).astype(self.dtype) + self.inputs = {'X': self.x} + self.attrs = { + 'axis': self.axis, + 'dtype': int(core.VarDesc.VarType.INT32) + } + self.outputs = { + 'Out': np.argmax(self.x, axis=self.axis).astype("int32") + } + + def test_check_output(self): + self.check_output_with_place(self.place) + + +class BaseTestComplex1_2(OpTest): + + def set_mlu(self): + self.__class__.use_mlu = True + self.place = paddle.MLUPlace(0) + + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (4, 5, 6) + self.dtype = 'float16' + self.axis = 2 + + def setUp(self): + self.set_mlu() + self.initTestCase() + self.x = (np.random.random(self.dims)).astype(self.dtype) + self.inputs = {'X': self.x} + self.attrs = { + 'axis': self.axis, + 'dtype': int(core.VarDesc.VarType.INT32) + } + self.outputs = { + 'Out': np.argmax(self.x, axis=self.axis).astype("int32") + } + + def test_check_output(self): + self.check_output_with_place(self.place) + + +class TestArgMaxAPI(unittest.TestCase): + + def initTestCase(self): + self.dims = (3, 4, 5) + self.dtype = 'float32' + self.axis = 0 + + def setUp(self): + self.initTestCase() + self.__class__.use_mlu = True + self.place = [paddle.MLUPlace(0)] + + def test_dygraph_api(self): + + def run(place): + paddle.disable_static(place) + np.random.seed(2022) + numpy_input = (np.random.random(self.dims)).astype(self.dtype) + tensor_input = paddle.to_tensor(numpy_input) + numpy_output = np.argmax(numpy_input, axis=self.axis) + paddle_output = paddle.argmax(tensor_input, axis=self.axis) + self.assertEqual(np.allclose(numpy_output, paddle_output.numpy()), + True) + paddle.enable_static() + + for place in self.place: + run(place) + + +class TestArgMaxAPI_2(unittest.TestCase): + + def initTestCase(self): + self.dims = (3, 4, 5) + self.dtype = 'float32' + self.axis = 0 + self.keep_dims = True + + def setUp(self): + self.initTestCase() + self.__class__.use_mlu = True + self.place = [paddle.MLUPlace(0)] + + def test_dygraph_api(self): + + def run(place): + paddle.disable_static(place) + np.random.seed(2022) + numpy_input = (np.random.random(self.dims)).astype(self.dtype) + tensor_input = paddle.to_tensor(numpy_input) + numpy_output = np.argmax(numpy_input, + axis=self.axis).reshape(1, 4, 5) + paddle_output = paddle.argmax(tensor_input, + axis=self.axis, + keepdim=self.keep_dims) + self.assertEqual(np.allclose(numpy_output, paddle_output.numpy()), + True) + self.assertEqual(numpy_output.shape, paddle_output.numpy().shape) + 
paddle.enable_static() + + for place in self.place: + run(place) + + +class TestArgMaxAPI_3(unittest.TestCase): + + def initTestCase(self): + self.dims = (1, 9) + self.dtype = 'float32' + + def setUp(self): + self.initTestCase() + self.__class__.use_mlu = True + self.place = [paddle.MLUPlace(0)] + + def test_dygraph_api(self): + + def run(place): + paddle.disable_static(place) + np.random.seed(2022) + numpy_input = (np.random.random(self.dims)).astype(self.dtype) + tensor_input = paddle.to_tensor(numpy_input) + numpy_output = np.argmax(numpy_input).reshape([1]) + paddle_output = paddle.argmax(tensor_input) + self.assertEqual(np.allclose(numpy_output, paddle_output.numpy()), + True) + self.assertEqual(numpy_output.shape, paddle_output.numpy().shape) + paddle.enable_static() + + for place in self.place: + run(place) + + +if __name__ == '__main__': + unittest.main() -- GitLab
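
For reference, once this kernel is registered it is reached through the regular paddle.argmax API. The unit tests above cover the dygraph path; the following is a minimal static-graph sketch, not part of the patch, assuming a Paddle build with MLU support and an MLU device available. It passes dtype='int32', which maps to attr dtype == 2 and takes the direct int32-indices branch of ArgMaxMLUKernel, while the default int64 output would go through the extra cast branch.

    import numpy as np
    import paddle

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name='x', shape=[3, 4, 5], dtype='float32')
        # dtype='int32' selects the branch without the int32 -> int64 cast
        out = paddle.argmax(x, axis=-1, dtype='int32')

    place = paddle.MLUPlace(0)  # assumes an MLU device is present
    exe = paddle.static.Executor(place)
    exe.run(startup_prog)

    x_np = np.random.random([3, 4, 5]).astype('float32')
    result, = exe.run(main_prog, feed={'x': x_np}, fetch_list=[out])
    # ties are essentially impossible with random floats, so the MLU result
    # should match numpy's argmax exactly
    np.testing.assert_array_equal(result, np.argmax(x_np, axis=-1))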