From 9066828b1b119492c314f1302335e75c4c72fda1 Mon Sep 17 00:00:00 2001 From: 123malin Date: Fri, 20 Nov 2020 16:31:31 +0800 Subject: [PATCH] test=develop, bug fix for embeddings padding (#28708) * test=develop, bug fix for embeddings padding * fix raise Value for Embedding Change-Id: I6d343fceee369a5796ad59cca5c91fdd15429125 Co-authored-by: seiriosPlus --- python/paddle/nn/functional/input.py | 13 +++++++------ python/paddle/nn/layer/common.py | 8 +++++--- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/python/paddle/nn/functional/input.py b/python/paddle/nn/functional/input.py index ab5a000a2bf..40b9441c2dc 100644 --- a/python/paddle/nn/functional/input.py +++ b/python/paddle/nn/functional/input.py @@ -192,6 +192,13 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None): x=label, weight=weight, sparse=True, name="embedding") """ + padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( + weight.shape[0] + padding_idx) + + if padding_idx >= weight.shape[0] or padding_idx < -weight.shape[0]: + raise ValueError("padding_idx must be within [-{}, {})".format( + weight.shape[0], weight.shape[0])) + if in_dygraph_mode(): return core.ops.lookup_table_v2( weight, x, 'is_sparse', sparse, 'is_distributed', False, @@ -206,12 +213,6 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None): remote_prefetch = sparse and (not is_distributed) tmp = helper.create_variable_for_type_inference(dtype) - padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( - weight.shape[0] + padding_idx) - - if padding_idx >= weight.shape[0] or padding_idx < -weight.shape[0]: - raise ValueError("padding_idx must be within [-{}, {})".format( - weight.shape[0], weight.shape[0])) helper.append_op( type='lookup_table_v2', diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 6e3910745e1..cf8aa7a66e3 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -1103,8 +1103,7 @@ class Embedding(layers.Layer): self._embedding_dim = embedding_dim self._sparse = sparse self._is_distributed = False - self._padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( - num_embeddings + padding_idx) + self._padding_idx = padding_idx if self._num_embeddings <= 0: raise ValueError("num_embeddings must be gather than 0") @@ -1112,7 +1111,10 @@ class Embedding(layers.Layer): if self._embedding_dim <= 0: raise ValueError("embedding_dim must be gather than 0") - if self._padding_idx >= num_embeddings or self._padding_idx < -num_embeddings: + padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( + num_embeddings + padding_idx) + + if padding_idx >= num_embeddings or padding_idx < -num_embeddings: raise ValueError("padding_idx must be within [-{}, {})".format( num_embeddings, num_embeddings)) -- GitLab