未验证 提交 b85c9b56 编写于 作者: X xiongkun 提交者: GitHub

[ Dy2Static ] Fix bugs when select inputs meeting different shape or undefined-var (#45916)

* fix select_input with different shape errors:
1. select_input_with_buildin_type directly return non-undefinedvar branch when meeting undefined var
2. the output shape of select_input is inferred from inputs.

* reverse the logic in select_input
上级 6833ecfe
...@@ -145,21 +145,6 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0): ...@@ -145,21 +145,6 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
need_check_feed=False) need_check_feed=False)
def create_undefined_var_like(variable):
""" create a undefined var with the same shape and dtype like varaible.
"""
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
var = data_layer_not_check(unique_name.generate("undefined_var"),
variable.shape, variable.dtype)
var.stop_gradient = False
helper = LayerHelper('create_undefined_var_like', **locals())
saved_block_ids = helper.main_program.current_block_idx
helper.main_program.current_block_idx = 0
assign(RETURN_NO_VALUE_MAGIC_NUM, var)
helper.main_program.current_block_idx = saved_block_ids
return var
def create_undefined_variable(): def create_undefined_variable():
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
var = data_layer_not_check(unique_name.generate("undefined_var"), [1], var = data_layer_not_check(unique_name.generate("undefined_var"), [1],
......
...@@ -70,6 +70,25 @@ def select_output(input, outputs, mask): ...@@ -70,6 +70,25 @@ def select_output(input, outputs, mask):
return outputs return outputs
def _select_input_infer_shape(first_shape, second_shape):
"""
This function infer the output shape by following algorithm:
1. if the dims is different, raise a error.
2. compare axis one by one:
if a == b: we set axis to a
if a != b: we set axis to -1
for compatibility,non declarative mode, we just return second_shape.
"""
if len(first_shape) != len(second_shape):
warnings.warn(
f"the input shapes of select_input should have the same rank, but get {first_shape}, {second_shape}"
)
return second_shape
out_shape = list(
map(lambda a, b: a if a == b else -1, first_shape, second_shape))
return out_shape
def select_input(inputs, mask): def select_input(inputs, mask):
""" """
**select_input** **select_input**
...@@ -89,13 +108,15 @@ def select_input(inputs, mask): ...@@ -89,13 +108,15 @@ def select_input(inputs, mask):
check_type(inputs, 'inputs', (list, tuple), 'select_input') check_type(inputs, 'inputs', (list, tuple), 'select_input')
check_variable_and_dtype(mask, 'mask', ['int32'], 'select_input') check_variable_and_dtype(mask, 'mask', ['int32'], 'select_input')
input_dtype = inputs[1].dtype # Select input should expand the shape. If it is - 1 and valid number, use - 1 first. If the dim is different, an error will be reported directly
input_shape = inputs[1].shape #assert inputs[0].dtype == inputs[1].dtype, f"Expect the inputs should have the same dtype, but get {inputs[0].dtype} and {inputs[1].dtype}"
input_type = inputs[1].type output_shape = _select_input_infer_shape(inputs[0].shape, inputs[1].shape)
output_dtype = inputs[1].dtype
output_type = inputs[1].type
out = helper.create_variable(dtype=input_dtype, out = helper.create_variable(dtype=output_dtype,
shape=input_shape, shape=output_shape,
type=input_type) type=output_type)
helper.append_op(type='select_input', helper.append_op(type='select_input',
inputs={ inputs={
'X': inputs, 'X': inputs,
...@@ -105,9 +126,9 @@ def select_input(inputs, mask): ...@@ -105,9 +126,9 @@ def select_input(inputs, mask):
return out return out
def select_input_with_buildin_type(inputs, mask): def select_input_with_buildin_type(inputs, mask, name):
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_var_like from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
false_var, true_var = inputs false_var, true_var = inputs
if isinstance(false_var, UndefinedVar) and isinstance( if isinstance(false_var, UndefinedVar) and isinstance(
...@@ -117,7 +138,11 @@ def select_input_with_buildin_type(inputs, mask): ...@@ -117,7 +138,11 @@ def select_input_with_buildin_type(inputs, mask):
return None return None
if isinstance(false_var, Variable) and isinstance(true_var, Variable): if isinstance(false_var, Variable) and isinstance(true_var, Variable):
try:
return select_input(inputs, mask) return select_input(inputs, mask)
except Exception as e:
raise RuntimeError(
f"Exceptions throwed while doing select_input on {name}:\n{e}")
elif (isinstance(false_var, (support_ret_buildin_type)) elif (isinstance(false_var, (support_ret_buildin_type))
and isinstance(false_var, type(true_var))): and isinstance(false_var, type(true_var))):
...@@ -148,24 +173,19 @@ def select_input_with_buildin_type(inputs, mask): ...@@ -148,24 +173,19 @@ def select_input_with_buildin_type(inputs, mask):
if isinstance(a, UndefinedVar): return a if isinstance(a, UndefinedVar): return a
return to_static_variable(a) return to_static_variable(a)
def create_like_if_undefined_var(a, b): true_var, false_var = to_static_variable(true_var), to_static_variable(
if isinstance(a, UndefinedVar): return create_undefined_var_like(b) false_var)
return a inputs = [false_var, true_var]
# TODO(xiongkun): add warning here.
true_var, false_var = create_var_if_not_undefined_var(
true_var), create_var_if_not_undefined_var(false_var)
inputs = [
create_like_if_undefined_var(false_var, true_var),
create_like_if_undefined_var(true_var, false_var)
]
else: else:
raise TypeError( raise TypeError(
"Unsupported return type of true_fn and false_fn in cond: false_var " "Unsupported return type of true_fn and false_fn in cond: false_var "
"returned by fasle_fn is '{}' and true_var of true_fn is '{}'". "returned by fasle_fn is '{}' and true_var of true_fn is '{}'".
format(type(false_var), type(true_var))) format(type(false_var), type(true_var)))
try:
return select_input(inputs, mask) return select_input(inputs, mask)
except Exception as e:
raise RuntimeError(
f"Exceptions throwed while doing select_input on {name}:\n{e}")
def split_lod_tensor(input, mask, level=0): def split_lod_tensor(input, mask, level=0):
...@@ -2658,9 +2678,16 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): ...@@ -2658,9 +2678,16 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
.format(return_name, e)) .format(return_name, e))
mask = cast(pred, dtype='int32') mask = cast(pred, dtype='int32')
merge_func = lambda false_var, true_var: select_input_with_buildin_type( merge_func = lambda name, false_var, true_var: select_input_with_buildin_type(
[false_var, true_var], mask) [false_var, true_var], mask, name)
merged_output = map_structure(merge_func, false_output, true_output)
def merge_every_var_list(false_vars, true_vars, name):
return map_structure(partial(merge_func, name), false_vars, true_vars)
merged_output = list(
map(merge_every_var_list, to_sequence(false_output),
to_sequence(true_output), to_sequence(return_names)))
merged_output = pack_sequence_as(false_output, flatten(merged_output))
return merged_output return merged_output
......
...@@ -164,7 +164,7 @@ def nested_if_else(x_v): ...@@ -164,7 +164,7 @@ def nested_if_else(x_v):
if paddle.mean(y).numpy()[0] < batch_size: if paddle.mean(y).numpy()[0] < batch_size:
y = fluid.layers.abs(y) y = fluid.layers.abs(y)
else: else:
tmp = fluid.layers.fill_constant([feat_size], tmp = fluid.layers.fill_constant(y.shape,
dtype='float32', dtype='float32',
value=-1) value=-1)
y = y - tmp y = y - tmp
...@@ -273,7 +273,7 @@ class NetWithControlFlowIf(fluid.dygraph.Layer): ...@@ -273,7 +273,7 @@ class NetWithControlFlowIf(fluid.dygraph.Layer):
[hidden_dim], dtype='float32', value=9) [hidden_dim], dtype='float32', value=9)
y = fluid.layers.abs(y) y = fluid.layers.abs(y)
else: else:
tmp = fluid.layers.fill_constant([5], tmp = fluid.layers.fill_constant(y.shape,
dtype='float32', dtype='float32',
value=-1) value=-1)
y = y - tmp y = y - tmp
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册