提交 822d64f0 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Fix embedding_lookup() bug where normalization did not work with ids of rank != 1.

PiperOrigin-RevId: 157422220
上级 8cad6b82
......@@ -547,6 +547,31 @@ class EmbeddingLookupTest(test.TestCase):
sharded = embedding_ops.embedding_lookup(split_params, ids).eval()
self.assertAllEqual(simple, sharded)
def testHigherRankMaxNorm(self):
  """Checks embedding_lookup(max_norm=...) for ids of rank 0, 1, 2 and 3.

  Guards the fix where max_norm normalization only worked for rank-1 ids:
  the normalized lookup must equal a gather on pre-normalized params, and
  sharded lookups must agree with the non-sharded result.
  """
  np.random.seed(8)
  with self.test_session():
    for params_shape in (12,), (6, 3):
      # Constant params so every embedding has the same known norm (> 1),
      # guaranteeing that max_norm=1.0 actually clips.
      params = 2 * np.ones(params_shape)
      # Normalize each embedding (all axes after the first) to unit norm.
      params_norm = params / np.sqrt(
          np.sum(params * params, tuple(range(params.ndim)[1:]),
                 keepdims=True))
      # BUG FIX: the original wrote (3) — an int, not a tuple — which only
      # worked because np.prod/reshape happen to accept a bare int.
      for ids_shape in (), (3,), (4, 3), (2, 3, 4):
        ids = np.random.randint(
            params.shape[0],
            size=np.prod(ids_shape, dtype=np.int64)).reshape(ids_shape)
        # Compare nonsharded lookup to gather on pre-normalized params.
        simple = embedding_ops.embedding_lookup(
            params, ids, max_norm=1.0).eval()
        self.assertAllEqual(simple,
                            array_ops.gather(params_norm, ids).eval())
        # Run a few random sharded versions; all must match `simple`.
        for procs in 1, 2, 3:
          stride = procs * math_ops.range(params.shape[0] // procs)
          split_params = [
              array_ops.gather(params, stride + p) for p in xrange(procs)
          ]
          sharded = embedding_ops.embedding_lookup(
              split_params, ids, max_norm=1.0).eval()
          self.assertAllEqual(simple, sharded)
class EmbeddingLookupSparseTest(test.TestCase):
......
......@@ -103,14 +103,25 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
params = list(params) # Iterate to get the underlying Variables.
if not isinstance(params, list):
params = [params]
def maybe_normalize(x):
  """Clips each embedding in `x` to `max_norm`; no-op when it is None.

  The embedding axes of `x` are every axis past the rank of `ids`, so
  clipping is applied per looked-up id regardless of the rank of `ids`.
  Ranks are resolved statically when both shapes are fully known and
  fall back to dynamic `rank` ops otherwise.
  """
  if max_norm is None:
    return x
  ranks_static = True
  lo = ops.convert_to_tensor(ids).get_shape().ndims
  if lo is None:
    ranks_static = False
    lo = array_ops.rank(ids)
  hi = x.get_shape().ndims
  if hi is None:
    ranks_static = False
    hi = array_ops.rank(x)
  if ranks_static:
    clip_axes = list(range(lo, hi))
  else:
    # At least one rank is only known at graph-run time; build the axis
    # range as a tensor instead of a Python list.
    clip_axes = math_ops.range(lo, hi)
  return clip_ops.clip_by_norm(x, max_norm, axes=clip_axes)
with ops.name_scope(name, "embedding_lookup", params + [ids]) as name:
np = len(params) # Number of partitions
# Preserve the resource variable status to avoid accidental dense reads.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册