From 7afd31bb78023bccf1d4d64d9d9b3ece8aaca48a Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Mon, 9 Aug 2021 15:17:58 +0800 Subject: [PATCH] [NPU] Support npu op flatten2_grad (#34669) --- paddle/fluid/operators/flatten_op_npu.cc | 26 +++++++++++++++++++ .../unittests/npu/test_flatten2_op_npu.py | 3 +++ 2 files changed, 29 insertions(+) mode change 100644 => 100755 python/paddle/fluid/tests/unittests/npu/test_flatten2_op_npu.py diff --git a/paddle/fluid/operators/flatten_op_npu.cc b/paddle/fluid/operators/flatten_op_npu.cc index 3fe1f041919..385dad530d9 100644 --- a/paddle/fluid/operators/flatten_op_npu.cc +++ b/paddle/fluid/operators/flatten_op_npu.cc @@ -34,6 +34,26 @@ class Flatten2NPUKernel : public framework::OpKernel { runner.Run(stream); } }; + +template +class Flatten2GradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_out = + ctx.Input(framework::GradVarName("Out")); + + auto xshape_dims = ctx.Input("XShape")->dims(); + auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); + + d_x->mutable_data(ctx.GetPlace(), d_out->type()); + framework::TensorCopy( + *d_out, ctx.GetPlace(), + ctx.template device_context(), d_x); + d_x->Resize(x_dims); + } +}; + } // namespace operators } // namespace paddle @@ -45,3 +65,9 @@ REGISTER_OP_NPU_KERNEL(flatten2, ops::Flatten2NPUKernel, ops::Flatten2NPUKernel, ops::Flatten2NPUKernel, ops::Flatten2NPUKernel); +REGISTER_OP_NPU_KERNEL(flatten2_grad, ops::Flatten2GradNPUKernel, + ops::Flatten2GradNPUKernel, + ops::Flatten2GradNPUKernel, + ops::Flatten2GradNPUKernel, + ops::Flatten2GradNPUKernel, + ops::Flatten2GradNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_flatten2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_flatten2_op_npu.py old mode 100644 new mode 100755 index 0fc0d1b7ac4..acd7ca77016 --- a/python/paddle/fluid/tests/unittests/npu/test_flatten2_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_flatten2_op_npu.py @@ -43,6 +43,9 @@ class TestFlatten2Op(OpTest): def test_check_output(self): self.check_output_with_place(self.place, no_check_set=["XShape"]) + def test_check_grad(self): + self.check_grad_with_place(self.place, ["X"], "Out") + def init_test_case(self): self.in_shape = (3, 2, 4, 5) self.axis = 1 -- GitLab