From 0759e99d8a4ba233850dbffe87954a2b6a628776 Mon Sep 17 00:00:00 2001 From: helen88 Date: Mon, 18 Apr 2022 11:39:36 +0800 Subject: [PATCH] support tril_triu_grad for KL2, *test=kunlun (#41877) --- paddle/fluid/operators/tril_triu_op_xpu.cc | 32 +++++++++++++++++++ .../fluid/platform/device/xpu/xpu2_op_list.h | 3 ++ .../unittests/xpu/test_tril_triu_op_xpu.py | 18 ++++++++--- 3 files changed, 49 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/tril_triu_op_xpu.cc b/paddle/fluid/operators/tril_triu_op_xpu.cc index a44ea8ff68..70200fe733 100644 --- a/paddle/fluid/operators/tril_triu_op_xpu.cc +++ b/paddle/fluid/operators/tril_triu_op_xpu.cc @@ -43,6 +43,34 @@ class TrilTriuXPUKernel : public framework::OpKernel { } }; +template +class TrilTriuGradXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const auto* d_out = + context.Input(framework::GradVarName("Out")); + const auto* dout_data = d_out->data(); + auto* d_x = context.Output(framework::GradVarName("X")); + auto* dx_data = d_x->mutable_data(context.GetPlace()); + + const int diagonal = context.Attr("diagonal"); + const bool lower = context.Attr("lower"); + + auto dy_shape = phi::vectorize(d_out->dims()); + auto& dev_ctx = context.template device_context(); + int r = 0; + if (lower) { + r = xpu::tril(dev_ctx.x_context(), dout_data, dx_data, dy_shape, + diagonal); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "tril_op"); + } else { + r = xpu::triu(dev_ctx.x_context(), dout_data, dx_data, dy_shape, + diagonal); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "triu_op"); + } + } +}; + } // namespace operators } // namespace paddle @@ -50,4 +78,8 @@ namespace ops = paddle::operators; REGISTER_OP_XPU_KERNEL( tril_triu, ops::TrilTriuXPUKernel, ops::TrilTriuXPUKernel); +REGISTER_OP_XPU_KERNEL( + tril_triu_grad, + ops::TrilTriuGradXPUKernel, + ops::TrilTriuGradXPUKernel); #endif diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 6f4826bd8c..7b88f261d5 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -380,6 +380,9 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::FP16, XPUPlace())})}, {"tril_triu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace())})}, + {"tril_triu_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace())})}, {"tile", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()), diff --git a/python/paddle/fluid/tests/unittests/xpu/test_tril_triu_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_tril_triu_op_xpu.py index fb6b28d9c2..ee689efbb3 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_tril_triu_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_tril_triu_op_xpu.py @@ -42,6 +42,7 @@ class XPUTestTrilTriuOp(XPUOpTestWrapper): self.real_np_op = getattr(np, self.real_op_type) self.set_xpu() self.op_type = "tril_triu" + self.place = paddle.XPUPlace(0) if self.dtype == np.int32: self.X = np.arange( 1, self.get_Xshape_prod() + 1, @@ -69,13 +70,22 @@ class XPUTestTrilTriuOp(XPUOpTestWrapper): def set_xpu(self): self.__class__.use_xpu = True - self.__class__.no_need_check_grad = True + self.__class__.no_need_check_grad = False self.__class__.op_type = self.real_op_type def test_check_output(self): - if paddle.is_compiled_with_xpu(): - place = paddle.XPUPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(self.place) + + def test_check_grad_normal(self): + if self.dtype == np.int32: + user_defined_grad_outputs = np.random.random( + self.Xshape).astype('float32') + self.check_grad_with_place( + self.place, ['X'], + 'Out', + user_defined_grad_outputs=user_defined_grad_outputs) + else: + self.check_grad_with_place(self.place, ['X'], 'Out') def initTestCase(self): self.diagonal = None -- GitLab