From e674af23bd6e05df6b04afd7f78eae785d30f5be Mon Sep 17 00:00:00 2001 From: baoachun <962571062@qq.com> Date: Fri, 18 Feb 2022 11:33:01 +0800 Subject: [PATCH] refactor the forward implementation of shape npu op (#39613) --- paddle/fluid/operators/shape_op_npu.cc | 33 +++++++++---------- .../tests/unittests/npu/test_shape_op_npu.py | 20 +++++++++++ 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/operators/shape_op_npu.cc b/paddle/fluid/operators/shape_op_npu.cc index 89a1e952d1d..7bff7b2d668 100644 --- a/paddle/fluid/operators/shape_op_npu.cc +++ b/paddle/fluid/operators/shape_op_npu.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,9 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include - #include "paddle/fluid/operators/shape_op.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" @@ -27,20 +24,20 @@ template class ShapeNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* in_var = ctx.InputVar("Input"); - framework::DDim in_dims; - if (in_var->IsType()) { - in_dims = in_var->Get().value().dims(); - } else { - in_dims = in_var->Get().dims(); - } + auto* x = ctx.Input("Input"); auto* out_t = ctx.Output("Out"); - out_t->Resize({in_dims.size()}); - // to do: cpuplace? - auto out_data = out_t->mutable_data(platform::CPUPlace()); - for (int i = 0; i < in_dims.size(); ++i) { - out_data[i] = in_dims[i]; - } + out_t->Resize({x->dims().size()}); + out_t->mutable_data(ctx.GetPlace()); + + // The output data type defaults to int32. + auto stream = + ctx.template device_context() + .stream(); + NpuOpRunner runner; + auto dst_dtype = ConvertToNpuDtype(framework::proto::VarType::INT32); + runner.SetType("Shape").AddInput(*x).AddOutput(*out_t).AddAttr( + "dtype", static_cast(dst_dtype)); + runner.Run(stream); } }; @@ -55,5 +52,7 @@ REGISTER_OP_NPU_KERNEL( ops::ShapeNPUKernel, ops::ShapeNPUKernel, ops::ShapeNPUKernel, + ops::ShapeNPUKernel, ops::ShapeNPUKernel, ops::ShapeNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_shape_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_shape_op_npu.py index cb1b0c458fc..0adfb69cd63 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_shape_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_shape_op_npu.py @@ -51,5 +51,25 @@ class TestShape(OpTest): self.check_output_with_place(self.place) +class TestShape_fp16(TestShape): + def init_dtype(self): + self.dtype = np.float16 + + +class TestShape_double(TestShape): + def init_dtype(self): + self.dtype = np.float64 + + +class TestShape_int32(TestShape): + def init_dtype(self): + self.dtype = np.int32 + + +class TestShape_int64(TestShape): + def init_dtype(self): + self.dtype = np.int64 + + if __name__ == '__main__': unittest.main() -- GitLab