未验证 提交 0759e99d 编写于 作者: z8hanghuan 提交者: GitHub

support tril_triu_grad for KL2, *test=kunlun (#41877)

上级 ceef73c9
......@@ -43,6 +43,34 @@ class TrilTriuXPUKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class TrilTriuGradXPUKernel : public framework::OpKernel<T> {
 public:
  // Backward kernel for tril/triu on XPU: the gradient of a triangular
  // extraction is the same triangular mask applied to the upstream
  // gradient, so dX = tril(dOut) (lower) or dX = triu(dOut) (upper).
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* grad_out =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* grad_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    const T* src = grad_out->data<T>();
    // Allocates dX on the current XPU place and exposes its raw buffer.
    T* dst = grad_x->mutable_data<T>(ctx.GetPlace());
    const int diagonal = ctx.Attr<int>("diagonal");
    const bool lower = ctx.Attr<bool>("lower");
    auto out_shape = phi::vectorize<int>(grad_out->dims());
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    if (lower) {
      int ret = xpu::tril(dev_ctx.x_context(), src, dst, out_shape, diagonal);
      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "tril_op");
    } else {
      int ret = xpu::triu(dev_ctx.x_context(), src, dst, out_shape, diagonal);
      PADDLE_ENFORCE_XDNN_SUCCESS(ret, "triu_op");
    }
  }
};
} // namespace operators
} // namespace paddle
......@@ -50,4 +78,8 @@ namespace ops = paddle::operators;
// Forward tril/triu XPU kernel: registered for int32 and float32 only
// (matches the dtype set declared in the KL2 op map).
REGISTER_OP_XPU_KERNEL(
tril_triu, ops::TrilTriuXPUKernel<paddle::platform::XPUDeviceContext, int>,
ops::TrilTriuXPUKernel<paddle::platform::XPUDeviceContext, float>);
// Gradient kernel added by this change; same dtype coverage as forward so
// grad checks can run on XPU for both supported dtypes.
REGISTER_OP_XPU_KERNEL(
tril_triu_grad,
ops::TrilTriuGradXPUKernel<paddle::platform::XPUDeviceContext, int>,
ops::TrilTriuGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
......@@ -380,6 +380,9 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::FP16, XPUPlace())})},
{"tril_triu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"tril_triu_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"tile", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
......
......@@ -42,6 +42,7 @@ class XPUTestTrilTriuOp(XPUOpTestWrapper):
self.real_np_op = getattr(np, self.real_op_type)
self.set_xpu()
self.op_type = "tril_triu"
self.place = paddle.XPUPlace(0)
if self.dtype == np.int32:
self.X = np.arange(
1, self.get_Xshape_prod() + 1,
......@@ -69,13 +70,22 @@ class XPUTestTrilTriuOp(XPUOpTestWrapper):
def set_xpu(self):
    """Configure the test class for XPU execution.

    Fix: the original body assigned ``no_need_check_grad = True`` and then
    immediately overwrote it with ``False`` — the first assignment is a dead
    store (leftover from the pre-grad-kernel version of this test). Only the
    final value matters: grad checking is enabled now that the XPU
    tril_triu_grad kernel exists.
    """
    self.__class__.use_xpu = True
    # The tril_triu_grad XPU kernel is available, so run gradient checks.
    self.__class__.no_need_check_grad = False
    self.__class__.op_type = self.real_op_type
def test_check_output(self):
    """Verify forward output on XPU.

    Fix: the original body ran the output check twice — once through a
    freshly built local ``paddle.XPUPlace(0)`` and again through
    ``self.place`` (a stale line left beside its replacement). A single
    check against ``self.place`` (initialized in ``setUp``) is sufficient
    and keeps the place consistent with the grad test.
    """
    if paddle.is_compiled_with_xpu():
        self.check_output_with_place(self.place)
def test_check_grad_normal(self):
    """Verify the gradient of X on XPU.

    For int32 inputs Paddle cannot synthesize a gradient output itself,
    so a random float32 upstream gradient is supplied explicitly; for
    float inputs the default grad check is used.
    """
    if self.dtype != np.int32:
        self.check_grad_with_place(self.place, ['X'], 'Out')
        return
    # int32 path: hand the checker a float32 dOut of the same shape as X.
    grad_outs = np.random.random(self.Xshape).astype('float32')
    self.check_grad_with_place(
        self.place, ['X'],
        'Out',
        user_defined_grad_outputs=grad_outs)
def initTestCase(self):
self.diagonal = None
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册