未验证 提交 92253f11 编写于 作者: R ronnywang 提交者: GitHub

fix paddle.where broadcast bug (#39182)

上级 99cfcc09
......@@ -585,26 +585,49 @@ def where(condition, x=None, y=None, name=None):
condition_shape = list(condition.shape)
x_shape = list(x.shape)
y_shape = list(y.shape)
if x_shape == y_shape and condition_shape == x_shape:
if in_dygraph_mode():
return _C_ops.where(condition, x, y)
else:
helper = LayerHelper("where", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='where',
inputs={'Condition': condition,
'X': x,
'Y': y},
outputs={'Out': [out]})
broadcast_condition = condition
broadcast_x = x
broadcast_y = y
else:
if core.is_compiled_with_xpu():
cond_int = layers.cast(condition, x.dtype)
cond_not_int = layers.cast(layers.logical_not(condition), x.dtype)
out1 = layers.elementwise_mul(x, cond_int)
out2 = layers.elementwise_mul(y, cond_not_int)
out = layers.elementwise_add(out1, out2)
return out
zeros_like_x = layers.zeros_like(x)
zeros_like_y = layers.zeros_like(y)
zeros_like_condition = layers.zeros_like(condition)
zeros_like_condition = layers.cast(zeros_like_condition, x.dtype)
cast_cond = layers.cast(condition, x.dtype)
broadcast_zeros = layers.elementwise_add(zeros_like_x, zeros_like_y)
broadcast_zeros = layers.elementwise_add(broadcast_zeros,
zeros_like_condition)
broadcast_x = layers.elementwise_add(x, broadcast_zeros)
broadcast_y = layers.elementwise_add(y, broadcast_zeros)
broadcast_condition = layers.elementwise_add(cast_cond, broadcast_zeros)
broadcast_condition = layers.cast(broadcast_condition, 'bool')
if in_dygraph_mode():
return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y)
else:
cond_int = layers.cast(condition, x.dtype)
cond_not_int = layers.cast(layers.logical_not(condition), x.dtype)
out1 = layers.elementwise_mul(x, cond_int)
out2 = layers.elementwise_mul(y, cond_not_int)
out = layers.elementwise_add(out1, out2)
helper = LayerHelper("where", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='where',
inputs={
'Condition': broadcast_condition,
'X': broadcast_x,
'Y': broadcast_y
},
outputs={'Out': [out]})
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册