From 50bc11621f53d04285d786bd259279cd50d19ec0 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Thu, 18 Mar 2021 19:54:30 +0800 Subject: [PATCH] [NPU] fix reshape npu op kernel (#31726) * rename npu op file * fix reshape --- .../{reshape2_op_npu.cc => reshape_op_npu.cc} | 37 +++++++------------ ...hape2_op_npu.py => test_reshape_op_npu.py} | 0 2 files changed, 13 insertions(+), 24 deletions(-) rename paddle/fluid/operators/{reshape2_op_npu.cc => reshape_op_npu.cc} (75%) rename python/paddle/fluid/tests/unittests/npu/{test_reshape2_op_npu.py => test_reshape_op_npu.py} (100%) diff --git a/paddle/fluid/operators/reshape2_op_npu.cc b/paddle/fluid/operators/reshape_op_npu.cc similarity index 75% rename from paddle/fluid/operators/reshape2_op_npu.cc rename to paddle/fluid/operators/reshape_op_npu.cc index 7ca85abcf7a..cd0f0bb2558 100644 --- a/paddle/fluid/operators/reshape2_op_npu.cc +++ b/paddle/fluid/operators/reshape_op_npu.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/npu_op_runner.h" namespace paddle { @@ -25,34 +26,22 @@ class Reshape2NPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* x = ctx.Input("X"); - auto* shape = ctx.Attr>> ("shape"); auto* out = ctx.Output("Out"); - auto org_shape = framework::vectorize(x->dims()); - // reshape - int64_t shape_all = 1; - int64_t org_shape_all = 1; - int index = -1; - for (int i = 0; i < shape.size(); i++) { - if (shape[i] == 0) { - shape[i] = org_shape[i]; - } - if (shape[i] == -1) { - index = i; - } else { - shape_all *= shape[i]; - } - org_shape_all *= org_shape[i]; + auto list_new_shape_tensor = + ctx.MultiInput("ShapeTensor"); + if (list_new_shape_tensor.size() > 0) { + PADDLE_THROW(platform::errors::Unimplemented( + "Input(ShapeTensor) is not supported on NPU.")); } - - if (index >= 0) { - shape[index] = org_shape_all / shape_all; - } - out.Resize(framework::make_ddim(shape)); + PADDLE_ENFORCE_EQ(ctx.Input("Shape"), nullptr, + platform::errors::Unimplemented( + "Input(Shape) is not supported on NPU.")); + auto shape = out->dims(); out->mutable_data(ctx.GetPlace(), x->type()); framework::TensorCopy( *x, ctx.GetPlace(), ctx.template device_context(), out); - out.Resize(framework::make_ddim(shape)); + out->Resize(shape); } }; @@ -77,11 +66,11 @@ class Reshape2GradNPUKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_NPU_KERNEL( - reshpe2, ops::Reshape2NPUKernel, + reshape2, ops::Reshape2NPUKernel, ops::Reshape2NPUKernel); REGISTER_OP_NPU_KERNEL( - reshpe2_grad, + reshape2_grad, ops::Reshape2GradNPUKernel, ops::Reshape2GradNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_reshape2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_reshape_op_npu.py similarity index 100% rename from python/paddle/fluid/tests/unittests/npu/test_reshape2_op_npu.py rename to python/paddle/fluid/tests/unittests/npu/test_reshape_op_npu.py -- GitLab