From 6d1fb79dc483afcebd8207b708e0c844bd5da1a3 Mon Sep 17 00:00:00 2001 From: littletomatodonkey <2120160898@bit.edu.cn> Date: Sat, 9 Jan 2021 19:39:44 +0800 Subject: [PATCH] fix pad (#30231) --- python/paddle/nn/functional/common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 7319b860db..fac5ca2f79 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1276,6 +1276,9 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): x_dim = len(x.shape) + if mode == "constant" and isinstance(pad, list) and len(pad) == x_dim * 2: + return layers.pad(x, pad, pad_value=value) + assert x_dim in [ 3, 4, 5 ], "input tesor dimension must be in [3, 4, 5] but got {}".format(x_dim) @@ -1291,9 +1294,6 @@ def pad(x, pad, mode='constant', value=0, data_format="NCHW", name=None): unsqueezed_dim = [] - if mode == "constant" and isinstance(pad, list) and len(pad) == x_dim * 2: - return layers.pad(x, pad, pad_value=value) - if isinstance(pad, Variable): if data_format in ["NCL", "NCHW", "NCDHW"]: data_format = "NCDHW" -- GitLab