未验证 提交 1e5fec39 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim] Fix get var in prim when list of single tensor (#56114)

* fix get var in prim

* fix stack test case
上级 0c7fdda9
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import typing
import paddle
......@@ -132,11 +133,13 @@ INT_DTYPE_2_STRING = {
}
def get_var_block(block, names):
def get_var_block(block, names, is_tensor_list=None):
assert isinstance(names, list)
if len(names) == 0:
return None
elif len(names) == 1:
if is_tensor_list:
return [block.var(names[0])]
return block.var(names[0])
else:
return [block.var(name) for name in names]
......@@ -179,7 +182,7 @@ def _get_args_values(op, phi_name):
"get attrs' values for api args' values"
args = op_info[phi_name]
args_list = args["args"].split(",")
inputs = []
inputs = collections.OrderedDict()
attrs = []
for item in args_list:
......@@ -212,9 +215,9 @@ def _get_args_values(op, phi_name):
"inputs" in op_content.keys()
and arg_name in op_content["inputs"].keys()
):
inputs.append(op_content["inputs"][arg_name])
inputs[op_content["inputs"][arg_name]] = arg_type
else:
inputs.append(arg_name)
inputs[arg_name] = arg_type
else:
attr_value = _get_attr_value(op, arg_type, arg_name)
attrs.append(attr_value)
......@@ -237,9 +240,16 @@ def prepare_python_api_arguments(op):
phi_name = op.type
inputs, attrs = _get_args_values(op, phi_name)
res = []
for item in inputs:
for item, tensor_type in inputs.items():
if item in op.input_names:
res.append(get_var_block(op.block, op.input(item)))
if tensor_type == "Tensor[]":
res.append(
get_var_block(
op.block, op.input(item), is_tensor_list=True
)
)
else:
res.append(get_var_block(op.block, op.input(item)))
else:
# Note: in some cases, inputs may be optional, thus assign None. Such case must be recorded.
res.append(None)
......
......@@ -375,5 +375,20 @@ class TestStackAPI_ZeroDim(unittest.TestCase):
paddle.enable_static()
class TestStackListOfSingleTensor(unittest.TestCase):
def setUp(self):
paddle.disable_static()
paddle.seed(2022)
self.x = [paddle.randn((4, 2, 6), dtype="float32")]
def test_list_single_tensor(self):
expect = paddle.stack(self.x)
paddle.fluid.core._set_prim_all_enabled(True)
st_model = paddle.jit.to_static(paddle.stack)
actual = st_model(self.x)
np.testing.assert_allclose(expect, actual)
paddle.enable_static()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册