未验证 提交 1587ad07 编写于 作者: T TTerror 提交者: GitHub

update reduce_max for kunlun, *test=kunlun (#42116)

上级 6700294c
...@@ -105,11 +105,10 @@ class ReduceMaxGradXPUKernel : public framework::OpKernel<T> { ...@@ -105,11 +105,10 @@ class ReduceMaxGradXPUKernel : public framework::OpKernel<T> {
" wrong value[%d %s].", " wrong value[%d %s].",
r, XPUAPIErrorMsg[r])); r, XPUAPIErrorMsg[r]));
// step 2. comparse out_brocast and x // step 2. comparse out_brocast and x
r = xpu::elementwise_equal<T>(dev_ctx.x_context(), x_data, brocast1, equal, r = xpu::equal<T>(dev_ctx.x_context(), x_data, brocast1, equal, x->numel());
x->numel());
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
r == xpu::Error_t::SUCCESS, true, r == xpu::Error_t::SUCCESS, true,
platform::errors::External("XPU elementwise_equal in reduce_max_grad " platform::errors::External("XPU equal in reduce_max_grad "
"op return wrong value[%d %s].", "op return wrong value[%d %s].",
r, XPUAPIErrorMsg[r])); r, XPUAPIErrorMsg[r]));
// step 3. get x_grad // step 3. get x_grad
......
...@@ -57,6 +57,9 @@ XPUOpMap& get_kl2_ops() { ...@@ -57,6 +57,9 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})}, pOpKernelType(vartype::INT32, XPUPlace())})},
{"check_finite_and_unscale",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册