From d876952deb009376d58eea87a6f79b2e9c44b848 Mon Sep 17 00:00:00 2001
From: zhaoying9105
Date: Tue, 21 Jun 2022 14:58:00 +0800
Subject: [PATCH] [MLU]: add argsort/argsort_grad kernel (#43574)

---
 paddle/fluid/operators/argsort_op_mlu.cc      | 109 ++++++++++++++++++
 paddle/fluid/operators/gather_op_mlu.cc       |   6 +-
 paddle/fluid/operators/mlu/mlu_baseop.cc      |  15 ++-
 paddle/fluid/operators/mlu/mlu_baseop.h       |  11 +-
 .../unittests/mlu/test_argsort_op_mlu.py      |  88 ++++++++++++++
 5 files changed, 224 insertions(+), 5 deletions(-)
 create mode 100644 paddle/fluid/operators/argsort_op_mlu.cc
 create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_argsort_op_mlu.py

diff --git a/paddle/fluid/operators/argsort_op_mlu.cc b/paddle/fluid/operators/argsort_op_mlu.cc
new file mode 100644
index 0000000000..1db97c1e01
--- /dev/null
+++ b/paddle/fluid/operators/argsort_op_mlu.cc
@@ -0,0 +1,109 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/mlu/mlu_baseop.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+class ArgsortMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<framework::LoDTensor>("X");
+    auto* output = ctx.Output<framework::LoDTensor>("Out");
+    auto* indices = ctx.Output<framework::LoDTensor>("Indices");
+    const auto& place = ctx.GetPlace();
+
+    const auto& sorted = true;
+    const bool descending = ctx.Attr<bool>("descending");
+
+    // if axis < 0, calculate the real axis
+    int axis = static_cast<int>(ctx.Attr<int>("axis"));
+    if (axis < 0) {
+      const auto& in_dims = input->dims();
+      axis += in_dims.size();
+    }
+
+    auto in_dims = input->dims();
+    size_t k = in_dims[axis];
+
+    output->mutable_data<T>(place);
+    indices->mutable_data<int64_t>(place);
+
+    // cnnl only supports int32/int16 indices
+    framework::Tensor indices_int32(framework::TransToPhiDataType(VT::INT32));
+    indices_int32.Resize(indices->dims());
+    indices_int32.mutable_data<int32_t>(place);
+
+    MLUCnnlTensorDesc input_desc(*input);
+    MLUCnnlTensorDesc values_output_desc(*output);
+    MLUCnnlTensorDesc indices_int32_desc(indices_int32);
+    MLUCnnl::TopK(ctx, k, axis, descending, sorted, input_desc.get(),
+                  GetBasePtr(input), values_output_desc.get(),
+                  GetBasePtr(output), indices_int32_desc.get(),
+                  GetBasePtr(&indices_int32));
+
+    // cast indices type to int64
+    MLUCnnlTensorDesc cast_output_desc(*indices);
+    cnnlCastDataType_t cast_type = GetCastDataType(VT::INT32, VT::INT64);
+    MLUCnnl::Cast(ctx, cast_type, indices_int32_desc.get(),
+                  GetBasePtr(&indices_int32), cast_output_desc.get(),
+                  GetBasePtr(indices));
+  }
+};
+
+template <typename T>
+class ArgsortGradMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* indices = ctx.Input<framework::LoDTensor>("Indices");
+    auto* dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto* dout = ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    int axis = ctx.Attr<int>("axis");
+    dx->mutable_data<T>(ctx.GetPlace());
+
+    auto in_dims = indices->dims();
+    axis = (axis < 0) ? (in_dims.size() + axis) : axis;
+    if (dout->numel() == 0) return;
+
+    MLUCnnlTensorDesc dout_desc(*dout);
+    MLUCnnlTensorDesc indices_desc(*indices);
+    MLUCnnlTensorDesc dx_desc(*dx);
+    MLUCnnl::ScatterFunctor(ctx, dx_desc.get(), GetBasePtr(dx), dout_desc.get(),
+                            GetBasePtr(dout), indices_desc.get(),
+                            GetBasePtr(indices), axis);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_MLU_KERNEL(argsort,
+                       ops::ArgsortMLUKernel<paddle::platform::float16>,
+                       ops::ArgsortMLUKernel<float>,
+                       ops::ArgsortMLUKernel<int8_t>,
+                       ops::ArgsortMLUKernel<uint8_t>,
+                       ops::ArgsortMLUKernel<int16_t>,
+                       ops::ArgsortMLUKernel<int>);
+
+REGISTER_OP_MLU_KERNEL(argsort_grad,
+                       ops::ArgsortGradMLUKernel<paddle::platform::float16>,
+                       ops::ArgsortGradMLUKernel<float>,
+                       ops::ArgsortGradMLUKernel<int8_t>,
+                       ops::ArgsortGradMLUKernel<uint8_t>,
+                       ops::ArgsortGradMLUKernel<int16_t>,
+                       ops::ArgsortGradMLUKernel<int>);
diff --git a/paddle/fluid/operators/gather_op_mlu.cc b/paddle/fluid/operators/gather_op_mlu.cc
index cf35e051ed..7f1592c148 100644
--- a/paddle/fluid/operators/gather_op_mlu.cc
+++ b/paddle/fluid/operators/gather_op_mlu.cc
@@ -91,9 +91,9 @@ class GatherGradOpMLUKernel : public framework::OpKernel<T> {
                              ToCnnlDataType(index->dtype()));
     MLUCnnlTensorDesc dout_desc(*dout);
     const cnnlScatterRefMode_t mode = CNNL_SCATTERREF_UPDATE;
-    MLUCnnl::ScatterFunctor(ctx, dx_desc.get(), GetBasePtr(dx), dout_desc.get(),
-                            GetBasePtr(dout), index_desc.get(),
-                            GetBasePtr(index), mode);
+    MLUCnnl::ScatterRefFunctor(ctx, dx_desc.get(), GetBasePtr(dx),
+                               dout_desc.get(), GetBasePtr(dout),
+                               index_desc.get(), GetBasePtr(index), mode);
   }
 };
 
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.cc b/paddle/fluid/operators/mlu/mlu_baseop.cc
index 445a9fecff..d9626ea20c 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.cc
+++ b/paddle/fluid/operators/mlu/mlu_baseop.cc
@@ -1134,7 +1134,7 @@ MLUCnnlDCNDesc::~MLUCnnlDCNDesc() {
                                  indices_desc, indices, output_desc, output));
 }
 
-/* static */ void MLUCnnl::ScatterFunctor(
+/* static */ void MLUCnnl::ScatterRefFunctor(
     const ExecutionContext& ctx, const cnnlTensorDescriptor_t params_desc,
     const void* params, const cnnlTensorDescriptor_t updates_desc,
     const void* updates, const cnnlTensorDescriptor_t indices_desc,
@@ -1146,6 +1146,19 @@ MLUCnnlDCNDesc::~MLUCnnlDCNDesc() {
                                          updates, 0, mode));
 }
 
+/* static */ void MLUCnnl::ScatterFunctor(
+    const ExecutionContext& ctx, const cnnlTensorDescriptor_t params_desc,
+    void* params, const cnnlTensorDescriptor_t updates_desc,
+    const void* updates, const cnnlTensorDescriptor_t indices_desc,
+    const void* indices, const int dim, const cnnlScatterMode_t mode) {
+  cnnlHandle_t handle = GetHandleFromCTX(ctx);
+
+  PADDLE_ENFORCE_MLU_SUCCESS(cnnlScatter(
+      handle, dim, params_desc, params, indices_desc, indices, updates_desc,
+      updates, params_desc, params, /* output_desc, output: same as params */
+      mode));
+}
+
 /* static */ void MLUCnnl::StridedSliceGrad(
     const ExecutionContext& ctx, const int begin[], const int end[],
     const int strides[], const cnnlTensorDescriptor_t input_desc,
diff --git a/paddle/fluid/operators/mlu/mlu_baseop.h b/paddle/fluid/operators/mlu/mlu_baseop.h
index 4ba7ae5ac6..3eacfca407 100644
--- a/paddle/fluid/operators/mlu/mlu_baseop.h
+++ b/paddle/fluid/operators/mlu/mlu_baseop.h
@@ -626,12 +626,21 @@ class MLUCnnl {
                       const cnnlTensorDescriptor_t indices_desc,
                       const void* indices,
                       const cnnlTensorDescriptor_t output_desc, void* output);
 
-  static void ScatterFunctor(
+  static void ScatterRefFunctor(
       const ExecutionContext& ctx, const cnnlTensorDescriptor_t params_desc,
       const void* params, const cnnlTensorDescriptor_t updates_desc,
       const void* updates, const cnnlTensorDescriptor_t indices_desc,
       const void* indices, const cnnlScatterRefMode_t mode);
 
+  static void ScatterFunctor(const ExecutionContext& ctx,
+                             const cnnlTensorDescriptor_t params_desc,
+                             void* params,
+                             const cnnlTensorDescriptor_t updates_desc,
+                             const void* updates,
+                             const cnnlTensorDescriptor_t indices_desc,
+                             const void* indices, const int dim,
+                             const cnnlScatterMode_t mode = CNNL_SCATTER);
+
   static void Range(const ExecutionContext& ctx, const void* start,
                     const void* end, const void* step,
                     const cnnlDataType_t output_dtype, void* output);
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_argsort_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_argsort_op_mlu.py
new file mode 100644
index 0000000000..8bfce1fe08
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_argsort_op_mlu.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import paddle
+import numpy as np
+import unittest
+import sys
+
+sys.path.append("..")
+from op_test import OpTest
+
+paddle.enable_static()
+
+SEED = 2022
+
+
+def gen_test_class(dtype, axis, descending):
+
+    class TestArgsortOp(OpTest):
+
+        def setUp(self):
+            np.random.seed(SEED)
+            self.set_mlu()
+            self.op_type = "argsort"
+            self.place = paddle.MLUPlace(0)
+            self.init_inputshape()
+            if 'int' in dtype:
+                self.x = np.random.choice(255, self.size, replace=False)
+                self.x = self.x.reshape(self.input_shape).astype(dtype)
+            else:
+                self.x = np.random.random(self.input_shape).astype(dtype)
+            self.inputs = {"X": self.x}
+            self.attrs = {"axis": axis, "descending": descending}
+            self.get_output()
+            self.outputs = {"Out": self.sorted_x, "Indices": self.indices}
+
+        def get_output(self):
+            if descending:
+                self.indices = np.flip(
+                    np.argsort(self.x, kind='heapsort', axis=axis), axis)
+                self.sorted_x = np.flip(
+                    np.sort(self.x, kind='heapsort', axis=axis), axis)
+            else:
+                self.indices = np.argsort(self.x, kind='heapsort', axis=axis)
+                self.sorted_x = np.sort(self.x, kind='heapsort', axis=axis)
+
+        def test_check_grad(self):
+            if dtype in ['float16', 'int8', 'uint8', 'int32']:
+                self.__class__.no_need_check_grad = True
+            else:
+                self.check_grad_with_place(self.place, ["X"], "Out")
+
+        def set_mlu(self):
+            self.__class__.use_mlu = True
+
+        def init_inputshape(self):
+            self.input_shape = (5, 2, 2, 3, 3)
+            self.size = np.prod(self.input_shape)
+
+        def test_check_output(self):
+            self.check_output_with_place(self.place)
+
+        def init_direction(self):
+            self.descending = False
+
+    cls_name = "{}_{}_{}_TestArgsortOp".format(dtype, axis, descending)
+    TestArgsortOp.__name__ = cls_name
+    globals()[cls_name] = TestArgsortOp
+
+
+for dtype in ['float32', 'float16', 'int8', 'uint8', 'int32']:
+    for axis in [1, 2, 3, -1]:
+        for descending in [False]:
+            gen_test_class(dtype, axis, descending)
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
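
For reference, a minimal dynamic-graph sketch of how the new kernels are exercised end to end. This is not part of the patch; it assumes a Paddle build compiled with MLU support and that the "mlu:0" device string is accepted by paddle.set_device, and it uses only the public paddle.argsort/paddle.sort API.

# Standalone usage sketch (assumption: MLU-enabled Paddle build).
# paddle.argsort/paddle.sort lower to the argsort op, so the forward pass
# should dispatch to ArgsortMLUKernel (cnnlTopK + int32 -> int64 cast) and the
# backward pass to ArgsortGradMLUKernel (cnnlScatter along the sort axis).
import numpy as np
import paddle

paddle.set_device("mlu:0")  # assumed device string for MLU builds

x_np = np.random.random((5, 2, 2, 3, 3)).astype("float32")
x = paddle.to_tensor(x_np, stop_gradient=False)

indices = paddle.argsort(x, axis=-1, descending=True)  # "Indices" output
sorted_x = paddle.sort(x, axis=-1, descending=True)    # "Out" output

# Forward check against numpy: descending sort of the last axis.
np.testing.assert_allclose(sorted_x.numpy(),
                           np.flip(np.sort(x_np, axis=-1), axis=-1))

# Backward check: sorting is a permutation, so d(sum(sorted_x))/dx is all ones.
sorted_x.sum().backward()
np.testing.assert_allclose(x.grad.numpy(), np.ones_like(x_np))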