未验证 提交 61de8af8 编写于 作者: W Weilong Wu 提交者: GitHub

fix tuple input for _conv_nd (#44108)

上级 19902a12
...@@ -130,6 +130,10 @@ def _conv_nd(x, ...@@ -130,6 +130,10 @@ def _conv_nd(x,
if bias is not None: if bias is not None:
channel_dim = channel_dim + len( channel_dim = channel_dim + len(
x.shape) if channel_dim < 0 else channel_dim x.shape) if channel_dim < 0 else channel_dim
if isinstance(x, tuple):
x = x[0]
if isinstance(bias, tuple):
bias = bias[0]
if len(bias.shape) < len(x.shape): if len(bias.shape) < len(x.shape):
tmp_bias = _C_ops.final_state_reshape( tmp_bias = _C_ops.final_state_reshape(
bias, bias.shape + bias, bias.shape +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册