diff --git a/paddle/fluid/operators/shape_op_mlu.cc b/paddle/fluid/operators/shape_op_mlu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..28d2c0146f7fc489558fa0d69219de6139195706
--- /dev/null
+++ b/paddle/fluid/operators/shape_op_mlu.cc
@@ -0,0 +1,68 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef PADDLE_WITH_MLU
+#include <string>
+
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using SelectedRows = phi::SelectedRows;
+
+template <typename T>
+class ShapeMLUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* in_var = ctx.InputVar("Input");
+    framework::DDim in_dims;
+    if (in_var->IsType<phi::SelectedRows>()) {
+      in_dims = in_var->Get<phi::SelectedRows>().value().dims();
+    } else {
+      in_dims = in_var->Get<LoDTensor>().dims();
+    }
+    auto* out_t = ctx.Output<Tensor>("Out");
+    out_t->Resize({in_dims.size()});
+    out_t->mutable_data<int32_t>(ctx.GetPlace());
+
+    // Write the input's shape into a temporary int32 tensor on the CPU.
+    Tensor shape_on_cpu(
+        framework::TransToPhiDataType(framework::proto::VarType::INT32));
+    shape_on_cpu.Resize({in_dims.size()});
+    auto cpu_data = shape_on_cpu.mutable_data<int32_t>(platform::CPUPlace());
+    for (int i = 0; i < in_dims.size(); ++i) {
+      cpu_data[i] = in_dims[i];
+    }
+
+    // Copy the shape tensor from the CPU to the MLU output tensor.
+    auto& dev_ctx = ctx.template device_context<platform::MLUDeviceContext>();
+    framework::TensorCopy(shape_on_cpu, ctx.GetPlace(), dev_ctx, out_t);
+    dev_ctx.Wait();
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_MLU_KERNEL(shape, ops::ShapeMLUKernel<bool>,
+                       ops::ShapeMLUKernel<int8_t>,
+                       ops::ShapeMLUKernel<uint8_t>, ops::ShapeMLUKernel<int>,
+                       ops::ShapeMLUKernel<int64_t>,
+                       ops::ShapeMLUKernel<paddle::platform::float16>,
+                       ops::ShapeMLUKernel<float>, ops::ShapeMLUKernel<double>);
+
+#endif
diff --git a/python/paddle/fluid/tests/unittests/mlu/test_shape_op_mlu.py b/python/paddle/fluid/tests/unittests/mlu/test_shape_op_mlu.py
new file mode 100644
index 0000000000000000000000000000000000000000..57a95d50a5e90a4fc60667e588a9e88074f5929b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mlu/test_shape_op_mlu.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+# from paddle.fluid import core
+# from paddle.fluid.op import Operator
+
+paddle.enable_static()
+SEED = 2022
+
+
+class TestShape(OpTest):
+
+    def setUp(self):
+        self.set_mlu()
+        self.op_type = "shape"
+        self.place = paddle.MLUPlace(0)
+
+        self.init_dtype()
+        np.random.seed(SEED)
+        x = np.random.uniform(1, 2, [5, 10]).astype(self.dtype)
+        out = np.array([5, 10])
+
+        self.inputs = {'Input': OpTest.np_dtype_to_fluid_dtype(x)}
+        self.attrs = {}
+        self.outputs = {'Out': out}
+
+    def set_mlu(self):
+        self.__class__.use_mlu = True
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place)
+
+
+class TestShape_fp16(TestShape):
+
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+class TestShape_double(TestShape):
+
+    def init_dtype(self):
+        self.dtype = np.float64
+
+
+class TestShape_int32(TestShape):
+
+    def init_dtype(self):
+        self.dtype = np.int32
+
+
+class TestShape_int64(TestShape):
+
+    def init_dtype(self):
+        self.dtype = np.int64
+
+
+class TestShape_int8(TestShape):
+
+    def init_dtype(self):
+        self.dtype = np.int8
+
+
+class TestShape_uint8(TestShape):
+
+    def init_dtype(self):
+        self.dtype = np.uint8
+
+
+class TestShape_bool(TestShape):
+
+    def init_dtype(self):
+        self.dtype = bool
+
+
+if __name__ == '__main__':
+    unittest.main()
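Usage note: below is a minimal sketch, not part of the patch, of how the new MLU kernel is exercised end to end through the static-graph API. It assumes an MLU build of Paddle with device 0 available, mirroring the unit test above; paddle.shape emits the `shape` op registered in shape_op_mlu.cc.

# Minimal sketch: assumes an MLU build of Paddle and that MLU device 0 exists.
import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[5, 10], dtype='float32')
    # paddle.shape emits the `shape` op, which dispatches to ShapeMLUKernel.
    out = paddle.shape(x)

exe = paddle.static.Executor(paddle.MLUPlace(0))
exe.run(startup_prog)
shape_val, = exe.run(main_prog,
                     feed={'x': np.random.uniform(1, 2, [5, 10]).astype('float32')},
                     fetch_list=[out])
print(shape_val)  # expected: [ 5 10]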