From 92253f11547bb1f9ff3356464b902da0c30a2abc Mon Sep 17 00:00:00 2001 From: ronnywang <524019753@qq.com> Date: Sat, 29 Jan 2022 18:53:04 +0800 Subject: [PATCH] fix paddle.where broadcast bug (#39182) --- python/paddle/tensor/search.py | 57 ++++++++++++++++++++++++---------- 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index e15d2d49d54..2a2e7d000a1 100644 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -585,26 +585,49 @@ def where(condition, x=None, y=None, name=None): condition_shape = list(condition.shape) x_shape = list(x.shape) y_shape = list(y.shape) + if x_shape == y_shape and condition_shape == x_shape: - if in_dygraph_mode(): - return _C_ops.where(condition, x, y) - else: - helper = LayerHelper("where", **locals()) - out = helper.create_variable_for_type_inference(dtype=x.dtype) - - helper.append_op( - type='where', - inputs={'Condition': condition, - 'X': x, - 'Y': y}, - outputs={'Out': [out]}) + broadcast_condition = condition + broadcast_x = x + broadcast_y = y + else: + if core.is_compiled_with_xpu(): + cond_int = layers.cast(condition, x.dtype) + cond_not_int = layers.cast(layers.logical_not(condition), x.dtype) + out1 = layers.elementwise_mul(x, cond_int) + out2 = layers.elementwise_mul(y, cond_not_int) + out = layers.elementwise_add(out1, out2) return out + + zeros_like_x = layers.zeros_like(x) + zeros_like_y = layers.zeros_like(y) + zeros_like_condition = layers.zeros_like(condition) + zeros_like_condition = layers.cast(zeros_like_condition, x.dtype) + cast_cond = layers.cast(condition, x.dtype) + + broadcast_zeros = layers.elementwise_add(zeros_like_x, zeros_like_y) + broadcast_zeros = layers.elementwise_add(broadcast_zeros, + zeros_like_condition) + broadcast_x = layers.elementwise_add(x, broadcast_zeros) + broadcast_y = layers.elementwise_add(y, broadcast_zeros) + broadcast_condition = layers.elementwise_add(cast_cond, broadcast_zeros) + broadcast_condition = layers.cast(broadcast_condition, 'bool') + + if in_dygraph_mode(): + return _C_ops.where(broadcast_condition, broadcast_x, broadcast_y) else: - cond_int = layers.cast(condition, x.dtype) - cond_not_int = layers.cast(layers.logical_not(condition), x.dtype) - out1 = layers.elementwise_mul(x, cond_int) - out2 = layers.elementwise_mul(y, cond_not_int) - out = layers.elementwise_add(out1, out2) + helper = LayerHelper("where", **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type='where', + inputs={ + 'Condition': broadcast_condition, + 'X': broadcast_x, + 'Y': broadcast_y + }, + outputs={'Out': [out]}) + return out -- GitLab