未验证 提交 9066828b 编写于 作者: 1 123malin 提交者: GitHub

test=develop, bug fix for embeddings padding (#28708)

* test=develop, bug fix for embeddings padding

* fix raise Value for Embedding

Change-Id: I6d343fceee369a5796ad59cca5c91fdd15429125
Co-authored-by: NseiriosPlus <tangwei12@baidu.com>
上级 655d5eb1
......@@ -192,6 +192,13 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
x=label, weight=weight, sparse=True, name="embedding")
"""
padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
weight.shape[0] + padding_idx)
if padding_idx >= weight.shape[0] or padding_idx < -weight.shape[0]:
raise ValueError("padding_idx must be within [-{}, {})".format(
weight.shape[0], weight.shape[0]))
if in_dygraph_mode():
return core.ops.lookup_table_v2(
weight, x, 'is_sparse', sparse, 'is_distributed', False,
......@@ -206,12 +213,6 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None):
remote_prefetch = sparse and (not is_distributed)
tmp = helper.create_variable_for_type_inference(dtype)
padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
weight.shape[0] + padding_idx)
if padding_idx >= weight.shape[0] or padding_idx < -weight.shape[0]:
raise ValueError("padding_idx must be within [-{}, {})".format(
weight.shape[0], weight.shape[0]))
helper.append_op(
type='lookup_table_v2',
......
......@@ -1103,8 +1103,7 @@ class Embedding(layers.Layer):
self._embedding_dim = embedding_dim
self._sparse = sparse
self._is_distributed = False
self._padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
num_embeddings + padding_idx)
self._padding_idx = padding_idx
if self._num_embeddings <= 0:
raise ValueError("num_embeddings must be gather than 0")
......@@ -1112,7 +1111,10 @@ class Embedding(layers.Layer):
if self._embedding_dim <= 0:
raise ValueError("embedding_dim must be gather than 0")
if self._padding_idx >= num_embeddings or self._padding_idx < -num_embeddings:
padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else (
num_embeddings + padding_idx)
if padding_idx >= num_embeddings or padding_idx < -num_embeddings:
raise ValueError("padding_idx must be within [-{}, {})".format(
num_embeddings, num_embeddings))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册