未验证 提交 1587ad07 编写于 作者: T TTerror 提交者: GitHub

update reduce_max for kunlun, *test=kunlun (#42116)

上级 6700294c
......@@ -105,11 +105,10 @@ class ReduceMaxGradXPUKernel : public framework::OpKernel<T> {
" wrong value[%d %s].",
r, XPUAPIErrorMsg[r]));
// step 2. comparse out_brocast and x
r = xpu::elementwise_equal<T>(dev_ctx.x_context(), x_data, brocast1, equal,
x->numel());
r = xpu::equal<T>(dev_ctx.x_context(), x_data, brocast1, equal, x->numel());
PADDLE_ENFORCE_EQ(
r == xpu::Error_t::SUCCESS, true,
platform::errors::External("XPU elementwise_equal in reduce_max_grad "
platform::errors::External("XPU equal in reduce_max_grad "
"op return wrong value[%d %s].",
r, XPUAPIErrorMsg[r]));
// step 3. get x_grad
......
......@@ -57,6 +57,9 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"check_finite_and_unscale",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册