diff --git a/paddle/fluid/operators/flatten_op_npu.cc b/paddle/fluid/operators/flatten_op_npu.cc index 3fe1f0419190d0ab636209a11aaba2ba57582035..385dad530d95df3f54bf44e0f99713f3747a0a04 100644 --- a/paddle/fluid/operators/flatten_op_npu.cc +++ b/paddle/fluid/operators/flatten_op_npu.cc @@ -34,6 +34,26 @@ class Flatten2NPUKernel : public framework::OpKernel { runner.Run(stream); } }; + +template +class Flatten2GradNPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_out = + ctx.Input(framework::GradVarName("Out")); + + auto xshape_dims = ctx.Input("XShape")->dims(); + auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size()); + + d_x->mutable_data(ctx.GetPlace(), d_out->type()); + framework::TensorCopy( + *d_out, ctx.GetPlace(), + ctx.template device_context(), d_x); + d_x->Resize(x_dims); + } +}; + } // namespace operators } // namespace paddle @@ -45,3 +65,9 @@ REGISTER_OP_NPU_KERNEL(flatten2, ops::Flatten2NPUKernel, ops::Flatten2NPUKernel, ops::Flatten2NPUKernel, ops::Flatten2NPUKernel); +REGISTER_OP_NPU_KERNEL(flatten2_grad, ops::Flatten2GradNPUKernel, + ops::Flatten2GradNPUKernel, + ops::Flatten2GradNPUKernel, + ops::Flatten2GradNPUKernel, + ops::Flatten2GradNPUKernel, + ops::Flatten2GradNPUKernel); diff --git a/python/paddle/fluid/tests/unittests/npu/test_flatten2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_flatten2_op_npu.py old mode 100644 new mode 100755 index 0fc0d1b7ac49110a79cc078d7debf0b53feb24e1..acd7ca770164e524178a594452c57c571ed93b29 --- a/python/paddle/fluid/tests/unittests/npu/test_flatten2_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_flatten2_op_npu.py @@ -43,6 +43,9 @@ class TestFlatten2Op(OpTest): def test_check_output(self): self.check_output_with_place(self.place, no_check_set=["XShape"]) + def test_check_grad(self): + self.check_grad_with_place(self.place, ["X"], "Out") + def init_test_case(self): self.in_shape = (3, 2, 4, 5) self.axis = 1