未验证 提交 7afd31bb 编写于 作者: Y YuanRisheng 提交者: GitHub

[NPU] Support npu op flatten2_grad (#34669)

上级 3380778f
......@@ -34,6 +34,26 @@ class Flatten2NPUKernel : public framework::OpKernel<T> {
runner.Run(stream);
}
};
template <typename T>
class Flatten2GradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto *d_out =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
d_x->mutable_data(ctx.GetPlace(), d_out->type());
framework::TensorCopy(
*d_out, ctx.GetPlace(),
ctx.template device_context<paddle::platform::NPUDeviceContext>(), d_x);
d_x->Resize(x_dims);
}
};
} // namespace operators
} // namespace paddle
......@@ -45,3 +65,9 @@ REGISTER_OP_NPU_KERNEL(flatten2, ops::Flatten2NPUKernel<float>,
ops::Flatten2NPUKernel<int>,
ops::Flatten2NPUKernel<int8_t>,
ops::Flatten2NPUKernel<int64_t>);
REGISTER_OP_NPU_KERNEL(flatten2_grad, ops::Flatten2GradNPUKernel<float>,
ops::Flatten2GradNPUKernel<double>,
ops::Flatten2GradNPUKernel<uint8_t>,
ops::Flatten2GradNPUKernel<int>,
ops::Flatten2GradNPUKernel<int8_t>,
ops::Flatten2GradNPUKernel<int64_t>);
......@@ -43,6 +43,9 @@ class TestFlatten2Op(OpTest):
def test_check_output(self):
self.check_output_with_place(self.place, no_check_set=["XShape"])
def test_check_grad(self):
self.check_grad_with_place(self.place, ["X"], "Out")
def init_test_case(self):
self.in_shape = (3, 2, 4, 5)
self.axis = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册