提交 9a748138 编写于 作者: A Alexandre Passos 提交者: GitHub

Merge pull request #8650 from suiyuan2009/add-distributed-aggregation-for-embedding_lookup_sparse

Add distributed aggregation for embedding lookup sparse
......@@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
......@@ -26,15 +28,18 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
__all__ = [
"safe_embedding_lookup_sparse", "scattered_embedding_lookup",
"scattered_embedding_lookup_sparse", "embedding_lookup_unique"
"scattered_embedding_lookup_sparse", "embedding_lookup_unique",
"embedding_lookup_sparse_with_distributed_aggregation"
]
......@@ -548,3 +553,326 @@ def _sampled_scattered_embedding_lookup_sparse(params,
return math_ops.unsorted_segment_sum(embeddings, segment_ids,
num_segments=num_segments,
name=name_scope)
def embedding_lookup_sparse_with_distributed_aggregation(params, sp_ids,
sp_weights, partition_strategy="mod", name=None, combiner=None,
max_norm=None):
"""Computes embeddings for the given ids and weights.
Embeddings belonging to same param are aggregated on that device first. This
op is intended to decrease data transmission and improve parallelism. See
`tf.nn.embedding_lookup_sparse` for the functionality and example of this op.
Args:
params: A single tensor representing the complete embedding tensor,
or a list of P tensors all of same shape except for the first dimension,
representing sharded embedding tensors. Alternatively, a
`PartitionedVariable`, created by partitioning along dimension 0. Each
element must be appropriately sized for the given `partition_strategy`.
sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
where N is typically batch size and M is arbitrary.
sp_weights: either a SparseTensor of float / double weights, or None to
indicate all weights should be taken to be 1. If specified, sp_weights
must have exactly the same shape and indices as sp_ids.
partition_strategy: A string specifying the partitioning strategy, relevant
if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
is `"mod"`. See `tf.nn.embedding_lookup` for more details.
name: Optional name for the op.
combiner: A string specifying the reduction op. Currently "mean", "sqrtn"
and "sum" are supported.
"sum" computes the weighted sum of the embedding results for each row.
"mean" is the weighted sum divided by the total weight.
"sqrtn" is the weighted sum divided by the square root of the sum of the
squares of the weights.
max_norm: If not None, each embedding is normalized to have l2 norm equal
to max_norm before combining.
Returns:
A dense tensor representing the combined embeddings for the
sparse ids. For each row in the dense tensor represented by sp_ids, the op
looks up the embeddings for all ids in that row, multiplies them by the
corresponding weight, and combines these embeddings as specified.
Raises:
TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
None nor SparseTensor.
ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
"""
if combiner is None:
logging.warn("The default value of combiner will change from \"mean\" "
"to \"sqrtn\" after 2016/11/01.")
combiner = "mean"
if combiner not in ("mean", "sqrtn", "sum"):
raise ValueError("combiner must be one of 'mean', 'sqrtn' or 'sum'")
if isinstance(params, variables.PartitionedVariable):
params = list(params) # Iterate to get the underlying Variables.
if not isinstance(params, list):
params = [params]
if not isinstance(sp_ids, sparse_tensor.SparseTensor):
raise TypeError("sp_ids must be SparseTensor")
ignore_weights = sp_weights is None
if not ignore_weights:
if not isinstance(sp_weights, sparse_tensor.SparseTensor):
raise TypeError("sp_weights must be either None or SparseTensor")
sp_ids.values.get_shape().assert_is_compatible_with(
sp_weights.values.get_shape())
sp_ids.indices.get_shape().assert_is_compatible_with(
sp_weights.indices.get_shape())
sp_ids.dense_shape.get_shape().assert_is_compatible_with(
sp_weights.dense_shape.get_shape())
# TODO(yleon): Add enhanced node assertions to verify that sp_ids and
# sp_weights have equal indices and shapes.
with ops.name_scope(name, "embedding_lookup_sparse",
params + [sp_ids]) as name:
segment_ids = sp_ids.indices[:, 0]
if segment_ids.dtype != dtypes.int32:
segment_ids = math_ops.cast(segment_ids, dtypes.int32)
ids = sp_ids.values
if ignore_weights:
ids, idx = array_ops.unique(ids)
else:
idx = None
weights = None if ignore_weights else sp_weights.values
embeddings = _embedding_lookup_with_distributed_aggregation(
params, ids, partition_strategy=partition_strategy, max_norm=max_norm,
weights=weights, idx=idx, segment_ids=segment_ids)
# Set weights to all one if ignore weights.
if ignore_weights:
weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
if weights.dtype != embeddings.dtype:
weights = math_ops.cast(weights, embeddings.dtype)
# Reshape weights.
ones = array_ops.fill(
array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
0)
orig_weights_shape = weights.get_shape()
weights = array_ops.reshape(weights, bcast_weights_shape)
if embeddings.get_shape().ndims is not None:
weights.set_shape(orig_weights_shape.concatenate(
[1 for _ in range(embeddings.get_shape().ndims - 1)]))
if combiner == "mean":
weight_sum = math_ops.segment_sum(weights, segment_ids)
embeddings = math_ops.div(embeddings, weight_sum)
elif combiner == "sqrtn":
weights_squared = math_ops.pow(weights, 2)
weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
weight_sum_sqrt = math_ops.sqrt(weight_sum)
embeddings = math_ops.div(embeddings, weight_sum_sqrt)
elif combiner != "sum":
assert False, "Unrecognized combiner"
return embeddings
def _do_gather(params, ids, validate_indices=True, name=None):
"""Deals with doing gather differently for resource variables."""
if isinstance(params, resource_variable_ops.ResourceVariable):
return params.sparse_read(ids, name=name)
return array_ops.gather(
params, ids, name=name, validate_indices=validate_indices)
def _embedding_lookup_with_distributed_aggregation(params, ids,
partition_strategy="mod", name=None, validate_indices=True, max_norm=None,
weights=None, idx=None, segment_ids=None):
""" Lookup helper for embedding_lookup_sparse_with_distributed_aggregation."""
if params is None or params == []: # pylint: disable=g-explicit-bool-comparison
raise ValueError("Need at least one param")
if isinstance(params, variables.PartitionedVariable):
params = list(params) # Iterate to get the underlying Variables.
if not isinstance(params, list):
params = [params]
def maybe_normalize(x):
if max_norm is not None:
if x.get_shape().ndims is not None:
ndims = x.get_shape().ndims
else:
ndims = array_ops.size(array_ops.shape(x))
return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
return x
with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation",
params + [ids]) as name:
np = len(params) # Number of partitions
# Preserve the resource variable status to avoid accidental dense reads.
if not any(isinstance(p, resource_variable_ops.ResourceVariable)
for p in params):
params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
if np == 1:
with ops.colocate_with(params[0]):
ret = maybe_normalize(
_do_gather(
params[0], ids, validate_indices=validate_indices))
ignore_weights = weights is None
if not ignore_weights:
if weights.dtype != ret.dtype:
weights = math_ops.cast(weights, ret.dtype)
# Reshape to allow broadcast
ones = array_ops.fill(
array_ops.expand_dims(array_ops.rank(ret) - 1, 0), 1)
bcast_weights_shape = array_ops.concat(
[array_ops.shape(weights), ones], 0)
orig_weights_shape = weights.get_shape()
weights = array_ops.reshape(weights, bcast_weights_shape)
# Set weights shape after reshape
if ret.get_shape().ndims is not None:
weights.set_shape(orig_weights_shape.concatenate(
[1 for _ in range(ret.get_shape().ndims - 1)]))
ret *= weights
return math_ops.segment_sum(ret, segment_ids, name=name)
else:
return math_ops.sparse_segment_sum(ret, idx, segment_ids, name=name)
else:
ids = ops.convert_to_tensor(ids, name="ids")
flat_ids = array_ops.reshape(ids, [-1])
original_indices = math_ops.range(array_ops.size(flat_ids))
# Create p_assignments and set new_ids depending on the strategy.
if partition_strategy == "mod":
p_assignments = flat_ids % np
new_ids = flat_ids // np
elif partition_strategy == "div":
# Compute num_total_ids as the sum of dim-0 of params, then assign to
# partitions based on a constant number of ids per partition. Optimize
# if we already know the full shape statically.
dim_0_size = params[0].get_shape()[0]
for p in xrange(1, np):
dim_0_size += params[p].get_shape()[0]
if dim_0_size.value:
num_total_ids = constant_op.constant(dim_0_size.value, flat_ids.dtype)
else:
dim_0_sizes = []
for p in xrange(np):
if params[p].get_shape()[0].value is not None:
dim_0_sizes.append(params[p].get_shape()[0].value)
else:
with ops.colocate_with(params[p]):
dim_0_sizes.append(array_ops.shape(params[p])[0])
num_total_ids = math_ops.reduce_sum(
math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
ids_per_partition = num_total_ids // np
extras = num_total_ids % np
p_assignments = math_ops.maximum(
flat_ids // (ids_per_partition + 1),
(flat_ids - extras) // ids_per_partition)
# Emulate a conditional using a boolean indicator tensor
is_in_first_extras_partitions = math_ops.cast(
p_assignments < extras, flat_ids.dtype)
new_ids = (
is_in_first_extras_partitions * (
flat_ids % (ids_per_partition + 1)) +
(1 - is_in_first_extras_partitions) * (
(flat_ids - extras) % ids_per_partition))
else:
raise ValueError("Unrecognized partition strategy: " +
partition_strategy)
# Cast partition assignments to int32 for use in dynamic_partition.
# There really should not be more than 2^32 partitions.
p_assignments = math_ops.cast(p_assignments, dtypes.int32)
# Partition list of ids based on assignments into np separate lists
gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
# Similarly, partition the original indices.
pindices = data_flow_ops.dynamic_partition(original_indices,
p_assignments, np)
# Do np separate lookups, finding embeddings for plist[p] in params[p]
partitioned_result = []
for p in xrange(np):
with ops.colocate_with(params[p]):
partitioned_result.append(
_do_gather(params[p], gather_ids[p],
validate_indices=validate_indices))
ignore_weights = weights is None
if not ignore_weights:
# Partition weights according to pindices.
partitioned_weight = []
for p in xrange(np):
partitioned_weight.append(array_ops.gather(weights, pindices[p]))
# Reshape each partition result.
element_shape = params[0].get_shape()[1:]
for p in params[1:]:
element_shape = element_shape.merge_with(p.get_shape()[1:])
if element_shape.is_fully_defined():
for p in xrange(np):
with ops.colocate_with(params[p]):
partitioned_result[p] = array_ops.reshape(partitioned_result[p],
array_ops.concat(
[array_ops.shape(pindices[p]), element_shape], 0))
else:
with ops.colocate_with(params[0]):
params_shape = array_ops.shape(params[0])
for p in xrange(np):
with ops.colocate_with(params[p]):
partitioned_result[p] = array_ops.reshape(partitioned_result[p],
array_ops.concat([array_ops.shape(pindices[p]),
array_ops.slice(params_shape, [1], [-1])], 0))
# Normalize each partition result.
for p in xrange(np):
with ops.colocate_with(params[p]):
partitioned_result[p] = maybe_normalize(partitioned_result[p])
if not ignore_weights:
# Multiply each partition result with partition weights.
for p in xrange(np):
with ops.colocate_with(params[p]):
if partitioned_weight[p].dtype != partitioned_result[p].dtype:
partitioned_weight[p] = math_ops.cast(partitioned_weight[p],
partitioned_result[p].dtype)
# Reshape partition weights.
ones = array_ops.fill(
array_ops.expand_dims(
array_ops.rank(partitioned_result[p]) - 1, 0), 1)
bcast_weights_shape = array_ops.concat(
[array_ops.shape(partitioned_weight[p]), ones], 0)
orig_weights_shape = partitioned_weight[p].get_shape()
partitioned_weight[p] = array_ops.reshape(partitioned_weight[p],
bcast_weights_shape)
if partitioned_result[p].get_shape().ndims is not None:
partitioned_weight[p].set_shape(orig_weights_shape.concatenate(
[1 for _ in range(
partitioned_result[p].get_shape().ndims - 1)]))
partitioned_result[p] *= partitioned_weight[p]
partitioned_segment_ids = []
for p in xrange(np):
if not ignore_weights:
# Partition segment_ids according to pindices.
p_segment_ids = array_ops.gather(segment_ids, pindices[p])
# Number the p_segment_ids to meet segment_sum's requirements. Note
# that unique_p_segment_ids contains unique segment ids of this
# partiton and these ids' order is unchanged.
unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
p_segment_ids)
partitioned_segment_ids.append(unique_p_segment_ids)
# segment_sum this partition's result.
with ops.colocate_with(params[p]):
partitioned_result[p] = math_ops.segment_sum(
partitioned_result[p], unique_p_segment_idx)
else:
# When ignore weights, we need to get indexs of elements in idx and
# segment_ids.
_, exclude_idx = array_ops.setdiff1d(idx, pindices[p])
all_idx = math_ops.range(array_ops.shape(idx)[0])
_, include_idx = array_ops.setdiff1d(all_idx, exclude_idx)
# Gather segment_ids and idx according to indexs.
p_segment_ids = array_ops.gather(segment_ids, include_idx)
p_idx = array_ops.gather(idx, include_idx)
# Number the p_segment_ids, same as ignore_weights case above.
unique_p_segment_ids, unique_p_segment_idx = array_ops.unique(
p_segment_ids)
_, unique_p_idx_idx = array_ops.unique(p_idx)
partitioned_segment_ids.append(unique_p_segment_ids)
with ops.colocate_with(params[p]):
partitioned_result[p] = math_ops.sparse_segment_sum(
partitioned_result[p], unique_p_idx_idx, unique_p_segment_idx)
# Concat each partition's segment_ids and result for final segment_sum.
concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
concat_partitioned_result = array_ops.concat(partitioned_result, 0)
return math_ops.unsorted_segment_sum(
concat_partitioned_result, concat_segment_ids,
math_ops.reduce_max(concat_segment_ids) + 1, name=name)
......@@ -32,9 +32,11 @@ from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.platform import test
from tensorflow.python.util import compat
class SafeEmbeddingLookupSparseTest(test.TestCase):
......@@ -563,5 +565,224 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
self.assertAllClose(result.eval(), result_abc.eval())
def _PName(param_id):
return "p" + str(param_id)
def _EmbeddingParams(num_shards,
vocab_size,
dtype=dtypes.float32,
shape=None,
use_shapeless_placeholder=False):
p = []
params = {}
feed_dict = {}
if not shape:
shape = [10]
for i in range(num_shards):
shard_shape = [vocab_size // num_shards] + shape
if i < vocab_size % num_shards: # Excess goes evenly on the first shards
shard_shape[0] += 1
param_name = _PName(i)
if use_shapeless_placeholder:
param = array_ops.placeholder(dtype, shape=None, name=param_name)
else:
param = constant_op.constant(
1.0, shape=shard_shape, dtype=dtype, name=param_name)
p.append(param)
np_type = "f" if dtype == dtypes.float32 else "d"
val = (np.random.rand(*shard_shape).astype(np_type)) + 1
params[param_name + ":0"] = val
feed_dict[param.name] = val
return p, params, feed_dict
def _EmbeddingResult(params,
id_vals,
num_shards,
vocab_size,
partition_strategy="mod",
weight_vals=None):
if weight_vals is None:
weight_vals = np.copy(id_vals)
weight_vals.fill(1)
values = []
weights = []
weights_squared = []
for ids, wts in zip(id_vals, weight_vals):
value_aggregation = None
weight_aggregation = None
squared_weight_aggregation = None
if isinstance(ids, compat.integral_types):
ids = [ids]
wts = [wts]
for i, weight_value in zip(ids, wts):
if partition_strategy == "mod":
val = np.copy(params[_PName(i % num_shards) + ":0"][
i // num_shards, :]) * weight_value
elif partition_strategy == "div":
ids_per_partition, extras = divmod(vocab_size, num_shards)
threshold = extras * (ids_per_partition + 1)
if i < threshold:
partition = i // (ids_per_partition + 1)
offset = i % (ids_per_partition + 1)
else:
partition = extras + (i - threshold) // ids_per_partition
offset = (i - threshold) % ids_per_partition
val = np.copy(params[_PName(partition) + ":0"][
offset, :]) * weight_value
else:
assert False
if value_aggregation is None:
assert weight_aggregation is None
assert squared_weight_aggregation is None
value_aggregation = val
weight_aggregation = weight_value
squared_weight_aggregation = weight_value * weight_value
else:
assert weight_aggregation is not None
assert squared_weight_aggregation is not None
value_aggregation += val
weight_aggregation += weight_value
squared_weight_aggregation += weight_value * weight_value
values.append(value_aggregation)
weights.append(weight_aggregation)
weights_squared.append(squared_weight_aggregation)
values = np.array(values).astype(np.float32)
weights = np.array(weights).astype(np.float32)
weights_squared = np.array(weights_squared).astype(np.float32)
return values, weights, weights_squared
class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
def _RandomIdsAndWeights(self, batch_size, vocab_size):
max_val_per_entry = 6
vals_per_batch_entry = np.random.randint(
1, max_val_per_entry, size=batch_size)
num_vals = np.sum(vals_per_batch_entry)
ids = np.random.randint(vocab_size, size=num_vals)
weights = 1 + np.random.rand(num_vals)
indices = []
for batch_entry, num_val in enumerate(vals_per_batch_entry):
for val_index in range(num_val):
indices.append([batch_entry, val_index])
shape = [batch_size, max_val_per_entry]
sp_ids = sparse_tensor_lib.SparseTensor(
constant_op.constant(indices, dtypes.int64),
constant_op.constant(ids, dtypes.int32),
constant_op.constant(shape, dtypes.int64))
sp_weights = sparse_tensor_lib.SparseTensor(
constant_op.constant(indices, dtypes.int64),
constant_op.constant(weights, dtypes.float32),
constant_op.constant(shape, dtypes.int64))
return sp_ids, sp_weights, ids, weights, vals_per_batch_entry
def _GroupByBatchEntry(self, vals, vals_per_batch_entry):
grouped_vals = []
index = 0
for num_val in vals_per_batch_entry:
grouped_vals.append(list(vals[index:(index + num_val)]))
index += num_val
return grouped_vals
def testEmbeddingLookupSparse(self):
vocab_size = 13
batch_size = 10
param_shape = [2, 5]
expected_lookup_result_shape = [None] + param_shape
sp_ids, sp_weights, ids, weights, vals_per_batch_entry = (
self._RandomIdsAndWeights(batch_size, vocab_size))
grouped_ids = self._GroupByBatchEntry(ids, vals_per_batch_entry)
grouped_weights = self._GroupByBatchEntry(weights, vals_per_batch_entry)
grouped_ignored_weights = self._GroupByBatchEntry(
np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry)
for num_shards, combiner, dtype, ignore_weights in itertools.product([1, 5],
["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64],
[True, False]):
with self.test_session():
p, params, feed_dict = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
embedding_sum = \
embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
p,
sp_ids,
None if ignore_weights else sp_weights,
combiner=combiner)
self.assertEqual(embedding_sum.get_shape().as_list(),
expected_lookup_result_shape)
tf_embedding_sum = embedding_sum.eval(feed_dict=feed_dict)
np_embedding_sum, np_weight_sum, np_weight_sq_sum = _EmbeddingResult(
params,
grouped_ids,
num_shards,
vocab_size,
weight_vals=grouped_ignored_weights if ignore_weights else
grouped_weights)
if combiner == "mean":
np_embedding_sum /= np.reshape(np_weight_sum, (batch_size, 1, 1))
if combiner == "sqrtn":
np_embedding_sum /= np.reshape(
np.sqrt(np_weight_sq_sum), (batch_size, 1, 1))
self.assertAllClose(np_embedding_sum, tf_embedding_sum)
def testGradientsEmbeddingLookupSparse(self):
vocab_size = 12
batch_size = 4
param_shape = [2, 3]
sp_ids, sp_weights, _, _, _ = (
self._RandomIdsAndWeights(batch_size, vocab_size))
for num_shards, combiner, dtype, ignore_weights in itertools.product([1, 3],
["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64],
[True, False]):
with self.test_session():
x, params, _ = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
y = embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
x,
sp_ids,
None if ignore_weights else sp_weights,
combiner=combiner)
x_name = [_PName(i) for i in range(num_shards)]
x_init_value = [params[x_n + ":0"] for x_n in x_name]
x_shape = [i.shape for i in x_init_value]
y_shape = [batch_size] + list(params[_PName(0) + ":0"].shape[1:])
err = gradient_checker.compute_gradient_error(
x, x_shape, y, y_shape, x_init_value=x_init_value)
self.assertLess(err, 1e-5 if dtype == dtypes.float64 else 2e-3)
def testIncompatibleShapes(self):
with self.test_session():
x, _, _ = _EmbeddingParams(1, 10, dtype=dtypes.float32)
sp_ids = sparse_tensor_lib.SparseTensor(
constant_op.constant([[0, 0], [0, 1], [1, 0]], dtypes.int64),
constant_op.constant([0, 1, 2], dtypes.int32),
constant_op.constant([2, 2], dtypes.int64))
sp_weights = sparse_tensor_lib.SparseTensor(
constant_op.constant([[0, 0], [0, 1]], dtypes.int64),
constant_op.constant([12.0, 5.0], dtypes.float32),
constant_op.constant([1, 2], dtypes.int64))
with self.assertRaises(ValueError):
embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
x, sp_ids, sp_weights, combiner="mean")
if __name__ == "__main__":
test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册