未验证 提交 91a26272 编写于 作者: L liym27 提交者: GitHub

fix bug in select_input: set var type of output to be the same as input. test=develop (#23076)

上级 c4a6a0e2
...@@ -83,10 +83,14 @@ def select_input(inputs, mask): ...@@ -83,10 +83,14 @@ def select_input(inputs, mask):
if isinstance(inputs, list) or isinstance(inputs, tuple): if isinstance(inputs, list) or isinstance(inputs, tuple):
input_dtype = inputs[0].dtype input_dtype = inputs[0].dtype
input_shape = inputs[0].shape input_shape = inputs[0].shape
input_type = inputs[0].type
else: else:
input_dtype = inputs.dtype input_dtype = inputs.dtype
input_shape = inputs.shape input_shape = inputs.shape
out = helper.create_variable(dtype=input_dtype, shape=input_shape) input_type = inputs.type
out = helper.create_variable(
dtype=input_dtype, shape=input_shape, type=input_type)
helper.append_op( helper.append_op(
type='select_input', type='select_input',
inputs={'X': inputs, inputs={'X': inputs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册