From 1a80b484ad27be2598f0b13f12886ce885689e07 Mon Sep 17 00:00:00 2001 From: fuyou765 <64373205+fuyou765@users.noreply.github.com> Date: Thu, 9 Jun 2022 18:38:26 +0800 Subject: [PATCH] [MLU]add mlu kernel for range op (#43296) --- paddle/fluid/operators/range_op_mlu.cc | 72 ++++++++++++++ .../tests/unittests/mlu/test_range_op_mlu.py | 93 +++++++++++++++++++ 2 files changed, 165 insertions(+) create mode 100644 paddle/fluid/operators/range_op_mlu.cc create mode 100644 python/paddle/fluid/tests/unittests/mlu/test_range_op_mlu.py diff --git a/paddle/fluid/operators/range_op_mlu.cc b/paddle/fluid/operators/range_op_mlu.cc new file mode 100644 index 00000000000..ceeb0cf5c36 --- /dev/null +++ b/paddle/fluid/operators/range_op_mlu.cc @@ -0,0 +1,72 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/range_op.h" + +namespace paddle { +namespace operators { + +template +class RangeMLUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* start_t = context.Input("Start"); + auto* end_t = context.Input("End"); + auto* step_t = context.Input("Step"); + auto* out = context.Output("Out"); + + framework::Tensor n; + framework::TensorCopy( + *start_t, platform::CPUPlace(), + context.template device_context(), &n); + context.template device_context() + .Wait(); + T start = n.data()[0]; + framework::TensorCopy( + *end_t, platform::CPUPlace(), + context.template device_context(), &n); + context.template device_context() + .Wait(); + T end = n.data()[0]; + framework::TensorCopy( + *step_t, platform::CPUPlace(), + context.template device_context(), &n); + context.template device_context() + .Wait(); + T step = n.data()[0]; + + int64_t size = 0; + GetSize(start, end, step, &size); + + out->Resize(phi::make_ddim({size})); + out->mutable_data(context.GetPlace()); + + std::vector odata; + T value = start; + for (int64_t i = 0; i < size; ++i) { + odata.push_back(value); + value += step; + } + + framework::TensorFromVector(odata, context.device_context(), out); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OP_MLU_KERNEL(range, paddle::operators::RangeMLUKernel, + paddle::operators::RangeMLUKernel, + paddle::operators::RangeMLUKernel, + paddle::operators::RangeMLUKernel) diff --git a/python/paddle/fluid/tests/unittests/mlu/test_range_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_range_op_mlu.py new file mode 100644 index 00000000000..f87bd2e85da --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mlu/test_range_op_mlu.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import sys + +sys.path.append("..") +import paddle +import unittest +import numpy as np +from op_test import OpTest +from functools import partial + +paddle.enable_static() + + +def arange_wrapper(start, end, step, dtype=None): + return paddle.arange(start, end, step, dtype) + + +class TestRangeOp(OpTest): + + def setUp(self): + self.op_type = "range" + self.place = paddle.device.MLUPlace(0) + self.__class__.use_mlu = True + self.init_config() + self.inputs = { + 'Start': np.array([self.case[0]]).astype(self.dtype), + 'End': np.array([self.case[1]]).astype(self.dtype), + 'Step': np.array([self.case[2]]).astype(self.dtype) + } + + self.outputs = { + 'Out': + np.arange(self.case[0], self.case[1], + self.case[2]).astype(self.dtype) + } + + def init_config(self): + self.dtype = np.float32 + self.python_api = partial(arange_wrapper, dtype=self.dtype) + self.case = (0, 1, 0.2) + + def test_check_output(self): + self.check_output_with_place(self.place, check_eager=False) + + +class TestFloatRangeOpCase0(TestRangeOp): + + def init_config(self): + self.dtype = np.float32 + self.python_api = partial(arange_wrapper, dtype=self.dtype) + self.case = (0, 5, 1) + + +class TestInt32RangeOpCase0(TestRangeOp): + + def init_config(self): + self.dtype = np.int32 + self.python_api = partial(arange_wrapper, dtype=self.dtype) + self.case = (0, 5, 2) + + +class TestInt32RangeOpCase1(TestRangeOp): + + def init_config(self): + self.dtype = np.int32 + self.python_api = partial(arange_wrapper, dtype=self.dtype) + self.case = (10, 1, -2) + + +class TestInt32RangeOpCase2(TestRangeOp): + + def init_config(self): + self.dtype = np.int32 + self.python_api = partial(arange_wrapper, dtype=self.dtype) + self.case = (-1, -10, -2) + + +if __name__ == "__main__": + unittest.main() -- GitLab