提交 8f75f24a 编写于 作者: P Peter Hawkins 提交者: TensorFlower Gardener

Add a function to create training variable slots from an initializer. Use it...

Add a function to create training variable slots from an initializer. Use it from the Adagrad optimizer.
Change: 150210582
上级 d478dfab
......@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
......@@ -60,10 +60,11 @@ class AdagradOptimizer(optimizer.Optimizer):
def _create_slots(self, var_list):
for v in var_list:
with ops.colocate_with(v):
val = constant_op.constant(self._initial_accumulator_value,
shape=v.get_shape(),
dtype=v.dtype.base_dtype)
self._get_or_make_slot(v, val, "accumulator", self._name)
dtype = v.dtype.base_dtype
init = init_ops.constant_initializer(self._initial_accumulator_value,
dtype=dtype)
self._get_or_make_slot_with_initializer(v, init, v.get_shape(), dtype,
"accumulator", self._name)
def _prepare(self):
self._learning_rate_tensor = ops.convert_to_tensor(self._learning_rate,
......
......@@ -727,6 +727,28 @@ class Optimizer(object):
named_slots[_var_key(var)] = slot_creator.create_slot(var, val, op_name)
return named_slots[_var_key(var)]
def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
slot_name, op_name):
"""Find or create a slot for a variable, using an Initializer.
Args:
var: A `Variable` object.
initializer: An `Initializer`. The initial value of the slot.
shape: Shape of the initial value of the slot.
dtype: Type of the value of the slot.
slot_name: Name for the slot.
op_name: Name to use when scoping the Variable that
needs to be created for the slot.
Returns:
A `Variable` object.
"""
named_slots = self._slot_dict(slot_name)
if _var_key(var) not in named_slots:
named_slots[_var_key(var)] = slot_creator.create_slot_with_initializer(
var, initializer, shape, dtype, op_name)
return named_slots[_var_key(var)]
def _zeros_slot(self, var, slot_name, op_name):
"""Find or create a slot initialized with 0.0.
......
......@@ -115,6 +115,40 @@ def create_slot(primary, val, name, colocate_with_primary=True):
return _create_slot_var(primary, val, "", validate_shape, None, None)
def create_slot_with_initializer(primary, initializer, shape, dtype, name,
colocate_with_primary=True):
"""Creates a slot initialized using an `Initializer`.
The type of the slot is determined by the given value.
Args:
primary: The primary `Variable` or `Tensor`.
initializer: An `Initializer`. The initial value of the slot.
shape: Shape of the initial value of the slot.
dtype: Type of the value of the slot.
name: Name to use for the slot variable.
colocate_with_primary: Boolean. If True the slot is located
on the same device as `primary`.
Returns:
A `Variable` object.
"""
# Scope the slot name in the namespace of the primary variable.
# Set "primary.op.name + '/' + name" as default name, so the scope name of
# optimizer can be shared when reuse is True. Meanwhile when reuse is False
# and the same name has been previously used, the scope name will add '_N'
# as suffix for unique identifications.
validate_shape = shape.is_fully_defined()
with variable_scope.variable_scope(None, primary.op.name + "/" + name):
if colocate_with_primary:
with ops.colocate_with(primary):
return _create_slot_var(primary, initializer, "", validate_shape, shape,
dtype)
else:
return _create_slot_var(primary, initializer, "", validate_shape, shape,
dtype)
def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
"""Create a slot initialized to 0 with same shape as the primary object.
......@@ -134,16 +168,11 @@ def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):
slot_shape = (slot_shape if slot_shape.is_fully_defined()
else array_ops.shape(primary.initialized_value()))
if slot_shape.is_fully_defined():
val = init_ops.zeros_initializer(dtype)
validate_shape = True
shape = slot_shape
initializer = init_ops.zeros_initializer(dtype)
return create_slot_with_initializer(
primary, initializer, slot_shape, dtype, name,
colocate_with_primary=colocate_with_primary)
else:
val = array_ops.zeros(slot_shape, dtype=dtype)
validate_shape = False
shape = None
with variable_scope.variable_scope(None, primary.op.name + "/" + name):
if colocate_with_primary:
with ops.colocate_with(primary):
return _create_slot_var(primary, val, "", validate_shape, shape, dtype)
else:
return _create_slot_var(primary, val, "", validate_shape, shape, dtype)
return create_slot(primary, val, name,
colocate_with_primary=colocate_with_primary)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册