Commit f2d00a47 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Add a partitioner that partitions a variable such that no partition gets less than the given minimum chunk size.

Switch from 'layers.legacy_fully_connected' to 'layers.fully_connected' in _DNNLinearCombinedBaseEstimator. Also, partition the weights of the fully connected layers using the added partitioner.
Change: 125492475
Parent 95237d1d
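
For illustration only (not part of the commit), a minimal sketch of how the new partitioner is meant to be used, assuming the TensorFlow API of this era; the scope name, variable name, and shape below are hypothetical:

```python
import tensorflow as tf

# Ask for at most 4 shards, each holding at least 256K of the variable.
partitioner = tf.min_max_variable_partitioner(max_partitions=4, axis=0,
                                              min_slice_size=256 << 10)

with tf.variable_scope("example", partitioner=partitioner):
  # A [2048, 1024] float32 variable is 8MB, i.e. 32 slices of 256K, so the
  # 4-partition cap (not the minimum slice size) determines the sharding here.
  weights = tf.get_variable("weights", shape=[2048, 1024], dtype=tf.float32)
```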
@@ -41,7 +41,9 @@ from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import partitioned_variables
 from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
 from tensorflow.python.training import training
@@ -247,24 +249,32 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
         self._get_dnn_feature_columns(),
         weight_collections=[self._dnn_weight_collection])
     for layer_id, num_hidden_units in enumerate(self._dnn_hidden_units):
-      net = layers.legacy_fully_connected(
-          net,
-          num_hidden_units,
-          activation_fn=self._dnn_activation_fn,
-          weight_collections=[self._dnn_weight_collection],
-          bias_collections=[self._dnn_weight_collection],
-          name="hiddenlayer_%d" % layer_id)
-      if self._dnn_dropout is not None and is_training:
-        net = layers.dropout(
-            net,
-            keep_prob=(1.0 - self._dnn_dropout))
-      self._add_hidden_layer_summary(net, "hiddenlayer_%d" % layer_id)
-    logit = layers.legacy_fully_connected(
-        net,
-        self._num_label_columns(),
-        weight_collections=[self._dnn_weight_collection],
-        bias_collections=[self._dnn_weight_collection],
-        name="dnn_logit")
+      op_scope = "hiddenlayer_%d" % layer_id
+      with variable_scope.variable_op_scope(
+          [net], op_scope,
+          partitioner=partitioned_variables.min_max_variable_partitioner(
+              max_partitions=self._config.num_ps_replicas)):
+        net = layers.fully_connected(
+            net,
+            num_hidden_units,
+            activation_fn=self._dnn_activation_fn,
+            variables_collections=[self._dnn_weight_collection],
+            scope=op_scope)
+        if self._dnn_dropout is not None and is_training:
+          net = layers.dropout(
+              net,
+              keep_prob=(1.0 - self._dnn_dropout))
+      self._add_hidden_layer_summary(net, op_scope)
+    with variable_scope.variable_op_scope(
+        [net], "dnn_logit",
+        partitioner=partitioned_variables.min_max_variable_partitioner(
+            max_partitions=self._config.num_ps_replicas)):
+      logit = layers.fully_connected(
+          net,
+          self._num_label_columns(),
+          activation_fn=None,
+          variables_collections=[self._dnn_weight_collection],
+          scope="dnn_logit")
     self._add_hidden_layer_summary(logit, "dnn_logit")
     return logit
......
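
The hunk above is easier to follow as a standalone sketch. Assuming the contrib layers API of this TensorFlow version, and with a hypothetical `num_ps_replicas` value (the estimator reads it from `self._config`), the pattern is: attach the partitioner to a `variable_op_scope` so the weights that `layers.fully_connected` creates inside it are sharded across parameter servers.

```python
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope

num_ps_replicas = 4  # hypothetical; in the estimator this comes from its config
inputs = tf.placeholder(tf.float32, shape=[None, 1024])

with variable_scope.variable_op_scope(
    [inputs], "hiddenlayer_0",
    partitioner=partitioned_variables.min_max_variable_partitioner(
        max_partitions=num_ps_replicas)):
  # Variables created inside this scope (the layer's weights and biases)
  # inherit the partitioner, so they are split into at most
  # `num_ps_replicas` shards of at least 256K each.
  net = layers.fully_connected(inputs, 512, scope="hiddenlayer_0")
```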
@@ -131,6 +131,93 @@ class PartitionerCreatorsTest(tf.test.TestCase):
      self.assertEqual(len(v3str_list), 4)
      self.assertAllEqual(v3str_part, (1, 1, 1, 4))

  def _testMinMaxVariablePartitioner(self, max_partitions, axis, min_slice_size,
                                     var_name, var_shape,
                                     expected_axis_shards, expected_partitions):
    partitioner = tf.min_max_variable_partitioner(max_partitions=max_partitions,
                                                  axis=axis,
                                                  min_slice_size=min_slice_size)
    with tf.variable_scope("root", partitioner=partitioner):
      v0 = tf.get_variable(var_name, dtype=tf.float32, shape=var_shape)
      v0_list = v0._get_variable_list()
      v0_part = v0._get_partitions()
      self.assertEqual(len(v0_list), expected_axis_shards)
      self.assertAllEqual(v0_part, expected_partitions)

  def testMinMaxVariablePartitioner(self):
    with self.test_session():
      # Partitioning a variable of shape=[2048] with a minimum of 2K per slice.
      self._testMinMaxVariablePartitioner(max_partitions=100, axis=0,
                                          min_slice_size=2 << 10,
                                          var_name="v0_0", var_shape=[2048],
                                          expected_axis_shards=4,
                                          expected_partitions=[4])

      # Partitioning a variable of shape=[2048, 1024] with a minimum of 256K per
      # slice.
      self._testMinMaxVariablePartitioner(max_partitions=100, axis=0,
                                          min_slice_size=256 << 10,
                                          var_name="v0", var_shape=[2048, 1024],
                                          expected_axis_shards=32,
                                          expected_partitions=[32, 1])

      # max_partitions restricts partitioning of the variable.
      self._testMinMaxVariablePartitioner(max_partitions=16, axis=0,
                                          min_slice_size=256 << 10,
                                          var_name="v1_max",
                                          var_shape=[2048, 1024],
                                          expected_axis_shards=16,
                                          expected_partitions=[16, 1])
      self._testMinMaxVariablePartitioner(max_partitions=1, axis=0,
                                          min_slice_size=256 << 10,
                                          var_name="v2_max",
                                          var_shape=[2048, 1024],
                                          expected_axis_shards=1,
                                          expected_partitions=[1, 1])

      # Reducing/Increasing min_slice_size proportionately increases/reduces the
      # number of partitions.
      self._testMinMaxVariablePartitioner(max_partitions=100, axis=0,
                                          min_slice_size=128 << 10,
                                          var_name="v3_slice",
                                          var_shape=[2048, 1024],
                                          expected_axis_shards=64,
                                          expected_partitions=[64, 1])
      self._testMinMaxVariablePartitioner(max_partitions=100, axis=0,
                                          min_slice_size=512 << 10,
                                          var_name="v4_slice",
                                          var_shape=[2048, 1024],
                                          expected_axis_shards=16,
                                          expected_partitions=[16, 1])

      # Partitioning the variable along a different axis.
      self._testMinMaxVariablePartitioner(max_partitions=100, axis=1,
                                          min_slice_size=256 << 10,
                                          var_name="v5_axis",
                                          var_shape=[64, 1024, 1, 3],
                                          expected_axis_shards=3,
                                          expected_partitions=[1, 3, 1, 1])
      self._testMinMaxVariablePartitioner(max_partitions=100, axis=3,
                                          min_slice_size=256 << 10,
                                          var_name="v6_axis",
                                          var_shape=[64, 1024, 1, 3],
                                          expected_axis_shards=3,
                                          expected_partitions=[1, 1, 1, 3])

      # Can not partition the variable more than what its shape allows.
      self._testMinMaxVariablePartitioner(max_partitions=100, axis=0,
                                          min_slice_size=256 << 10,
                                          var_name="v7_shape",
                                          var_shape=[16, 128, 1024],
                                          expected_axis_shards=16,
                                          expected_partitions=[16, 1, 1])
      self._testMinMaxVariablePartitioner(max_partitions=100, axis=0,
                                          min_slice_size=256 << 10,
                                          var_name="v8_shape",
                                          var_shape=[4, 512, 1024],
                                          expected_axis_shards=4,
                                          expected_partitions=[4, 1, 1])


def _IotaInitializer(shape, dtype=tf.float32):
  assert dtype == tf.float32
......
@@ -61,7 +61,11 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging as logging
 
-__all__ = ["create_partitioned_variables", "variable_axis_size_partitioner"]
+__all__ = [
+    "create_partitioned_variables",
+    "variable_axis_size_partitioner",
+    "min_max_variable_partitioner",
+]
 
 
 def variable_axis_size_partitioner(
@@ -148,6 +152,69 @@ def variable_axis_size_partitioner(
  return _partitioner


def min_max_variable_partitioner(max_partitions=1, axis=0,
                                 min_slice_size=256 << 10,
                                 bytes_per_string_element=16):
  """Partitioner to allocate minimum size per slice.

  Returns a partitioner that partitions the variable of given shape and dtype
  such that each partition has a minimum of `min_slice_size` slice of the
  variable. The maximum number of such partitions (upper bound) is given by
  `max_partitions`.

  Args:
    max_partitions: Upper bound on the number of partitions. Defaults to 1.
    axis: Axis along which to partition the variable. Defaults to 0.
    min_slice_size: Minimum size of the variable slice per partition. Defaults
      to 256K.
    bytes_per_string_element: If the `Variable` is of type string, this
      provides an estimate of how large each scalar in the `Variable` is.

  Returns:
    A partition function usable as the `partitioner` argument to
    `variable_scope`, `get_variable`, and `get_partitioned_variable_list`.
  """
  def _partitioner(shape, dtype):
    """Partitioner that partitions list for a variable of given shape and type.

    Ex: Consider partitioning a variable of type float32 with
      shape=[1024, 1024].
      If `max_partitions` >= 16, this function would return
        [(1024 * 1024 * 4) / (256 * 1024), 1] = [16, 1].
      If `max_partitions` < 16, this function would return
        [`max_partitions`, 1].

    Args:
      shape: Shape of the variable.
      dtype: Type of the variable.

    Returns:
      List of partitions for each axis (currently only one axis can be
      partitioned).

    Raises:
      ValueError: If axis to partition along does not exist for the variable.
    """
    if axis >= len(shape):
      raise ValueError("Can not partition variable along axis %d when shape is "
                       "only %s" % (axis, shape))
    if dtype.base_dtype == dtypes.string:
      bytes_per_element = bytes_per_string_element
    else:
      bytes_per_element = dtype.size
    total_size_bytes = shape.num_elements() * bytes_per_element
    partitions = total_size_bytes / min_slice_size
    partitions_list = [1] * len(shape)
    # We can not partition the variable beyond what its shape or
    # `max_partitions` allows.
    partitions_list[axis] = max(1, min(shape[axis].value,
                                       max_partitions,
                                       int(math.ceil(partitions))))
    return partitions_list
  return _partitioner


def create_partitioned_variables(
    shape, slicing, initializer, dtype=dtypes.float32,
    trainable=True, collections=None, name=None, reuse=None):
......
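
To make the arithmetic inside `_partitioner` concrete, here is a small sketch (again assuming this era's API and the `tensor_shape` import used above) that calls the returned partition function directly; the shape matches the `v0` test case above:

```python
import tensorflow as tf
from tensorflow.python.framework import tensor_shape

partitioner = tf.min_max_variable_partitioner(
    max_partitions=100, axis=0, min_slice_size=256 << 10)

# 2048 * 1024 float32 elements = 8MB total; 8MB / 256K = 32 slices, which is
# below both max_partitions (100) and the axis-0 dimension (2048).
print(partitioner(tensor_shape.TensorShape([2048, 1024]), tf.float32))
# -> [32, 1]
```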
@@ -68,6 +68,7 @@ create variables contingent on certain conditions.
## Variable Partitioners for Sharding
@@variable_axis_size_partitioner
@@min_max_variable_partitioner
## Sparse Variable Updates
......