diff --git a/paddle/fluid/operators/reshape2_op_npu.cc b/paddle/fluid/operators/reshape_op_npu.cc similarity index 75% rename from paddle/fluid/operators/reshape2_op_npu.cc rename to paddle/fluid/operators/reshape_op_npu.cc index 7ca85abcf7afcfb5e0cd6dd8fb9fc249a105690b..cd0f0bb2558974ea8d3024be88a452c927912b8b 100644 --- a/paddle/fluid/operators/reshape2_op_npu.cc +++ b/paddle/fluid/operators/reshape_op_npu.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/npu_op_runner.h" namespace paddle { @@ -25,34 +26,22 @@ class Reshape2NPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* x = ctx.Input("X"); - auto* shape = ctx.Attr>> ("shape"); auto* out = ctx.Output("Out"); - auto org_shape = framework::vectorize(x->dims()); - // reshape - int64_t shape_all = 1; - int64_t org_shape_all = 1; - int index = -1; - for (int i = 0; i < shape.size(); i++) { - if (shape[i] == 0) { - shape[i] = org_shape[i]; - } - if (shape[i] == -1) { - index = i; - } else { - shape_all *= shape[i]; - } - org_shape_all *= org_shape[i]; + auto list_new_shape_tensor = + ctx.MultiInput("ShapeTensor"); + if (list_new_shape_tensor.size() > 0) { + PADDLE_THROW(platform::errors::Unimplemented( + "Input(ShapeTensor) is not supported on NPU.")); } - - if (index >= 0) { - shape[index] = org_shape_all / shape_all; - } - out.Resize(framework::make_ddim(shape)); + PADDLE_ENFORCE_EQ(ctx.Input("Shape"), nullptr, + platform::errors::Unimplemented( + "Input(Shape) is not supported on NPU.")); + auto shape = out->dims(); out->mutable_data(ctx.GetPlace(), x->type()); framework::TensorCopy( *x, ctx.GetPlace(), ctx.template device_context(), out); - out.Resize(framework::make_ddim(shape)); + out->Resize(shape); } }; @@ -77,11 +66,11 @@ class Reshape2GradNPUKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_NPU_KERNEL( - reshpe2, ops::Reshape2NPUKernel, + reshape2, ops::Reshape2NPUKernel, ops::Reshape2NPUKernel); REGISTER_OP_NPU_KERNEL( - reshpe2_grad, + reshape2_grad, ops::Reshape2GradNPUKernel, ops::Reshape2GradNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_reshape2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_reshape_op_npu.py similarity index 100% rename from python/paddle/fluid/tests/unittests/npu/test_reshape2_op_npu.py rename to python/paddle/fluid/tests/unittests/npu/test_reshape_op_npu.py