未验证 提交 fc537d4f 编写于 作者: F Fan Zhang 提交者: GitHub

[NPU] Support npu op flatten_contiguous_range_grad (#34798)

上级 234c21ac
...@@ -78,6 +78,25 @@ class FlattenContiguousRangeNPUKernel : public framework::OpKernel<T> { ...@@ -78,6 +78,25 @@ class FlattenContiguousRangeNPUKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T>
class FlattenContiguousRangeGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *d_out =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopy(
*d_out, ctx.GetPlace(),
ctx.template device_context<paddle::platform::NPUDeviceContext>(), d_x);
d_x->Resize(x_dims);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -110,3 +129,17 @@ REGISTER_OP_NPU_KERNEL( ...@@ -110,3 +129,17 @@ REGISTER_OP_NPU_KERNEL(
int8_t>, int8_t>,
ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext, ops::FlattenContiguousRangeNPUKernel<paddle::platform::NPUDeviceContext,
int64_t>); int64_t>);
REGISTER_OP_NPU_KERNEL(
flatten_contiguous_range_grad,
ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
float>,
ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
double>,
ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
uint8_t>,
ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
int>,
ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
int8_t>,
ops::FlattenContiguousRangeGradNPUKernel<paddle::platform::NPUDeviceContext,
int64_t>);
...@@ -49,7 +49,7 @@ class TestFlattenOp(OpTest): ...@@ -49,7 +49,7 @@ class TestFlattenOp(OpTest):
self.check_output_with_place(self.place, no_check_set=["XShape"]) self.check_output_with_place(self.place, no_check_set=["XShape"])
def test_check_grad(self): def test_check_grad(self):
pass self.check_grad_with_place(self.place, ["X"], "Out")
def init_test_case(self): def init_test_case(self):
self.in_shape = (3, 2, 5, 4) self.in_shape = (3, 2, 5, 4)
...@@ -163,13 +163,13 @@ class TestFlattenOp_Float32(TestFlattenOp): ...@@ -163,13 +163,13 @@ class TestFlattenOp_Float32(TestFlattenOp):
} }
class TestFlattenOp_int(TestFlattenOp): class TestFlattenOp_int32(TestFlattenOp):
def init_test_case(self): def init_test_case(self):
self.in_shape = (3, 2, 5, 4) self.in_shape = (3, 2, 5, 4)
self.start_axis = 0 self.start_axis = 0
self.stop_axis = 1 self.stop_axis = 1
self.new_shape = (6, 5, 4) self.new_shape = (6, 5, 4)
self.dtype = np.int self.dtype = np.int32
def init_attrs(self): def init_attrs(self):
self.attrs = { self.attrs = {
...@@ -177,6 +177,9 @@ class TestFlattenOp_int(TestFlattenOp): ...@@ -177,6 +177,9 @@ class TestFlattenOp_int(TestFlattenOp):
"stop_axis": self.stop_axis "stop_axis": self.stop_axis
} }
def test_check_grad(self):
pass
class TestFlattenOp_uint8(TestFlattenOp): class TestFlattenOp_uint8(TestFlattenOp):
def init_test_case(self): def init_test_case(self):
...@@ -192,6 +195,9 @@ class TestFlattenOp_uint8(TestFlattenOp): ...@@ -192,6 +195,9 @@ class TestFlattenOp_uint8(TestFlattenOp):
"stop_axis": self.stop_axis "stop_axis": self.stop_axis
} }
def test_check_grad(self):
pass
class TestFlattenOp_int8(TestFlattenOp): class TestFlattenOp_int8(TestFlattenOp):
def init_test_case(self): def init_test_case(self):
...@@ -207,6 +213,9 @@ class TestFlattenOp_int8(TestFlattenOp): ...@@ -207,6 +213,9 @@ class TestFlattenOp_int8(TestFlattenOp):
"stop_axis": self.stop_axis "stop_axis": self.stop_axis
} }
def test_check_grad(self):
pass
class TestFlattenOp_int64(TestFlattenOp): class TestFlattenOp_int64(TestFlattenOp):
def init_test_case(self): def init_test_case(self):
...@@ -222,6 +231,9 @@ class TestFlattenOp_int64(TestFlattenOp): ...@@ -222,6 +231,9 @@ class TestFlattenOp_int64(TestFlattenOp):
"stop_axis": self.stop_axis "stop_axis": self.stop_axis
} }
def test_check_grad(self):
pass
class TestFlatten2OpError(unittest.TestCase): class TestFlatten2OpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册