未验证 提交 c921a812 编写于 作者: C Chen Weihang 提交者: GitHub

fix conv nd error (#42933)

上级 615d931c
...@@ -129,10 +129,13 @@ def _conv_nd(x, ...@@ -129,10 +129,13 @@ def _conv_nd(x,
if bias is not None: if bias is not None:
channel_dim = channel_dim + len( channel_dim = channel_dim + len(
x.shape) if channel_dim < 0 else channel_dim x.shape) if channel_dim < 0 else channel_dim
if len(bias.shape) < len(x.shape):
tmp_bias = _C_ops.final_state_reshape( tmp_bias = _C_ops.final_state_reshape(
bias, bias.shape + bias, bias.shape +
[1 for i in range(len(x.shape) - channel_dim - 1)]) [1 for i in range(len(x.shape) - channel_dim - 1)])
return _C_ops.final_state_add(pre_bias, tmp_bias) return _C_ops.final_state_add(pre_bias, tmp_bias)
else:
return _C_ops.final_state_add(pre_bias, bias)
else: else:
return pre_bias return pre_bias
if in_dynamic_mode(): if in_dynamic_mode():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册