未验证 提交 2cda4e21 编写于 作者: S Siming Dai 提交者: GitHub

Fix where xpu bug (#45832)

上级 b971ba04
...@@ -638,14 +638,6 @@ def where(condition, x=None, y=None, name=None): ...@@ -638,14 +638,6 @@ def where(condition, x=None, y=None, name=None):
broadcast_x = x broadcast_x = x
broadcast_y = y broadcast_y = y
else: else:
if core.is_compiled_with_xpu():
cond_int = paddle.cast(condition, x.dtype)
cond_not_int = paddle.cast(logical_not(condition), x.dtype)
out1 = paddle.multiply(x, cond_int)
out2 = paddle.multiply(y, cond_not_int)
out = paddle.add(out1, out2)
return out
zeros_like_x = paddle.zeros_like(x) zeros_like_x = paddle.zeros_like(x)
zeros_like_y = paddle.zeros_like(y) zeros_like_y = paddle.zeros_like(y)
zeros_like_condition = paddle.zeros_like(condition) zeros_like_condition = paddle.zeros_like(condition)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册