diff --git a/paddle/fluid/operators/flatten_op_npu.cc b/paddle/fluid/operators/flatten_op_npu.cc
index 1569760fe3b96fcfc3d49aa0a92d3199722d8529..9252716f3acfc1295281ca56646c1a354a1668ef 100644
--- a/paddle/fluid/operators/flatten_op_npu.cc
+++ b/paddle/fluid/operators/flatten_op_npu.cc
@@ -78,6 +78,25 @@ class FlattenContiguousRangeNPUKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class FlattenContiguousRangeGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto *d_out =
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+
+    auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
+    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+
+    d_x->mutable_data(ctx.GetPlace(), d_out->type());
+    framework::TensorCopy(
+        *d_out, ctx.GetPlace(),
+        ctx.template device_context<platform::DeviceContext>(), d_x);
+    d_x->Resize(x_dims);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -110,3 +129,17 @@ REGISTER_OP_NPU_KERNEL(
                                          int8_t>,
     ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
                                          int64_t>);
+REGISTER_OP_NPU_KERNEL(
+    flatten_contiguous_range_grad,
+    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                             float>,
+    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                             double>,
+    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                             uint8_t>,
+    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                             int>,
+    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                             int8_t>,
+    ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
+                                             int64_t>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py
old mode 100644
new mode 100755
index 88e711dcf068e6a03410ce9b01b7e455df38329b..742d156c7f5f1bc60cc959ec10e40618b7906b3d
--- a/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_flatten_contiguous_range_op_npu.py
@@ -49,7 +49,7 @@ class TestFlattenOp(OpTest):
         self.check_output_with_place(self.place, no_check_set=["XShape"])
 
     def test_check_grad(self):
-        pass
+        self.check_grad_with_place(self.place, ["X"], "Out")
 
     def init_test_case(self):
         self.in_shape = (3, 2, 5, 4)
@@ -163,13 +163,13 @@ class TestFlattenOp_Float32(TestFlattenOp):
         }
 
 
-class TestFlattenOp_int(TestFlattenOp):
+class TestFlattenOp_int32(TestFlattenOp):
     def init_test_case(self):
         self.in_shape = (3, 2, 5, 4)
         self.start_axis = 0
         self.stop_axis = 1
         self.new_shape = (6, 5, 4)
-        self.dtype = np.int
+        self.dtype = np.int32
 
     def init_attrs(self):
         self.attrs = {
@@ -177,6 +177,9 @@ class TestFlattenOp_int(TestFlattenOp):
             "stop_axis": self.stop_axis
         }
 
+    def test_check_grad(self):
+        pass
+
 
 class TestFlattenOp_uint8(TestFlattenOp):
     def init_test_case(self):
@@ -192,6 +195,9 @@ class TestFlattenOp_uint8(TestFlattenOp):
             "stop_axis": self.stop_axis
         }
 
+    def test_check_grad(self):
+        pass
+
 
 class TestFlattenOp_int8(TestFlattenOp):
     def init_test_case(self):
@@ -207,6 +213,9 @@ class TestFlattenOp_int8(TestFlattenOp):
             "stop_axis": self.stop_axis
         }
 
+    def test_check_grad(self):
+        pass
+
 
 class TestFlattenOp_int64(TestFlattenOp):
     def init_test_case(self):
@@ -222,6 +231,9 @@ class TestFlattenOp_int64(TestFlattenOp):
             "stop_axis": self.stop_axis
         }
 
+    def test_check_grad(self):
+        pass
+
 
 class TestFlatten2OpError(unittest.TestCase):
     def test_errors(self):
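
Reviewer note (not part of the patch): the grad kernel above does no arithmetic. Flatten only changes shape metadata, so the backward pass is a plain TensorCopy of d_out followed by a Resize to X's original dims, which are recovered from XShape (its leading dimension is a placeholder, hence the slice from index 1). A minimal NumPy sketch of that contract, using the shapes from TestFlattenOp (start_axis=0, stop_axis=1); the variable names are illustrative only:

```python
import numpy as np

# Forward: flatten axes [0, 1] of a (3, 2, 5, 4) input, as in TestFlattenOp.
x = np.random.rand(3, 2, 5, 4).astype(np.float32)
out = x.reshape(6, 5, 4)

# Backward, mirroring the kernel: copy d_out, then resize to x's dims.
d_out = np.ones_like(out)      # upstream gradient
d_x = d_out.reshape(x.shape)   # pure copy + reshape, no computation

assert d_x.shape == x.shape
```

This is also why the integer test classes (int32/uint8/int8/int64) override test_check_grad with pass: the new check_grad_with_place in the base class only makes sense for float inputs, since numeric gradient checking perturbs X by small fractional deltas.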