未验证 提交 e03171b7 编写于 作者: L littletomatodonkey 提交者: GitHub

fix pad (#30222)

上级 609c0222
...@@ -1276,6 +1276,9 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): ...@@ -1276,6 +1276,9 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None):
x_dim = len(x.shape) x_dim = len(x.shape)
if mode == "constant" and isinstance(pad, list) and len(pad) == x_dim * 2:
return layers.pad(x, pad, pad_value=value)
assert x_dim in [ assert x_dim in [
3, 4, 5 3, 4, 5
], "input tesor dimension must be in [3, 4, 5] but got {}".format(x_dim) ], "input tesor dimension must be in [3, 4, 5] but got {}".format(x_dim)
...@@ -1291,9 +1294,6 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): ...@@ -1291,9 +1294,6 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None):
unsqueezed_dim = [] unsqueezed_dim = []
if mode == "constant" and isinstance(pad, list) and len(pad) == x_dim * 2:
return layers.pad(x, pad, pad_value=value)
if isinstance(pad, Variable): if isinstance(pad, Variable):
if data_format in ["NCL", "NCHW", "NCDHW"]: if data_format in ["NCL", "NCHW", "NCDHW"]:
data_format = "NCDHW" data_format = "NCDHW"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册