未验证 提交 92253f11 编写于 作者: R ronnywang 提交者: GitHub

fix paddle.where broadcast bug (#39182)

上级 99cfcc09
...@@ -585,26 +585,49 @@ def where(condition, x=None, y=None, name=None): ...@@ -585,26 +585,49 @@ def where(condition, x=None, y=None, name=None):
condition_shape = list(condition.shape) condition_shape = list(condition.shape)
x_shape = list(x.shape) x_shape = list(x.shape)
y_shape = list(y.shape) y_shape = list(y.shape)
if x_shape == y_shape and condition_shape == x_shape: if x_shape == y_shape and condition_shape == x_shape:
broadcast_condition = condition
broadcast_x = x
broadcast_y = y
else:
if core.is_compiled_with_xpu():
cond_int = layers.cast(condition, x.dtype)
cond_not_int = layers.cast(layers.logical_not(condition), x.dtype)
out1 = layers.elementwise_mul(x, cond_int)
out2 = layers.elementwise_mul(y, cond_not_int)
out = layers.elementwise_add(out1, out2)
return out
zeros_like_x = layers.zeros_like(x)
zeros_like_y = layers.zeros_like(y)
zeros_like_condition = layers.zeros_like(condition)
zeros_like_condition = layers.cast(zeros_like_condition, x.dtype)
cast_cond = layers.cast(condition, x.dtype)
broadcast_zeros = layers.elementwise_add(zeros_like_x, zeros_like_y)
broadcast_zeros = layers.elementwise_add(broadcast_zeros,
zeros_like_condition)
broadcast_x = layers.elementwise_add(x, broadcast_zeros)
broadcast_y = layers.elementwise_add(y, broadcast_zeros)
broadcast_condition = layers.elementwise_add(cast_cond, broadcast_zeros)
broadcast_condition = layers.cast(broadcast_condition, 'bool')
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.where(condition, x, y) return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y)
else: else:
helper = LayerHelper("where", **locals()) helper = LayerHelper("where", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op( helper.append_op(
type='where', type='where',
inputs={'Condition': condition, inputs={
'X': x, 'Condition': broadcast_condition,
'Y': y}, 'X': broadcast_x,
'Y': broadcast_y
},
outputs={'Out': [out]}) outputs={'Out': [out]})
return out
else:
cond_int = layers.cast(condition, x.dtype)
cond_not_int = layers.cast(layers.logical_not(condition), x.dtype)
out1 = layers.elementwise_mul(x, cond_int)
out2 = layers.elementwise_mul(y, cond_not_int)
out = layers.elementwise_add(out1, out2)
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册