提交 8eb7343c 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Add support for non-trainable neural network layers.

This is useful for multi-task setups where a network is trained on one task and
then a subset of the layers are reused for a different (but related) task.
Change: 119310581
上级 b4698ac2
......@@ -51,6 +51,7 @@ def fully_connected(x,
weight_collections=(ops.GraphKeys.WEIGHTS,),
bias_collections=(ops.GraphKeys.BIASES,),
output_collections=(ops.GraphKeys.ACTIVATIONS,),
trainable=True,
weight_regularizer=None,
bias_regularizer=None):
# pylint: disable=anomalous-backslash-in-string
......@@ -107,6 +108,8 @@ def fully_connected(x,
weight_collections: List of graph collections to which weights are added.
bias_collections: List of graph collections to which biases are added.
output_collections: List of graph collections to which outputs are added.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
weight_regularizer: A regularizer like the result of
`l1_regularizer` or `l2_regularizer`. Used for weights.
bias_regularizer: A regularizer like the result of
......@@ -137,7 +140,8 @@ def fully_connected(x,
dtype=dtype,
initializer=weight_init,
collections=weight_collections,
regularizer=weight_regularizer)
regularizer=weight_regularizer,
trainable=trainable)
x_2_dim = x if len(dims) <= 2 else array_ops.reshape(x,
[-1, num_input_units])
y = standard_ops.matmul(x_2_dim, w)
......@@ -150,7 +154,8 @@ def fully_connected(x,
dtype=dtype,
initializer=bias_init,
collections=bias_collections,
regularizer=bias_regularizer)
regularizer=bias_regularizer,
trainable=trainable)
y = nn.bias_add(y, b)
......@@ -179,6 +184,7 @@ def convolution2d(x,
weight_collections=None,
bias_collections=None,
output_collections=None,
trainable=True,
weight_regularizer=None,
bias_regularizer=None):
"""Adds the parameters for a conv2d layer and returns the output.
......@@ -227,6 +233,8 @@ def convolution2d(x,
weight_collections: List of graph collections to which weights are added.
bias_collections: List of graph collections to which biases are added.
output_collections: List of graph collections to which outputs are added.
trainable: If `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
weight_regularizer: A regularizer like the result of
`l1_regularizer` or `l2_regularizer`. Used for weights.
bias_regularizer: A regularizer like the result of
......@@ -258,7 +266,8 @@ def convolution2d(x,
dtype=dtype,
initializer=weight_init,
collections=weight_collections,
regularizer=weight_regularizer)
regularizer=weight_regularizer,
trainable=trainable)
y = nn.conv2d(x, w, stride, padding)
......@@ -270,7 +279,8 @@ def convolution2d(x,
dtype=dtype,
initializer=bias_init,
collections=bias_collections,
regularizer=bias_regularizer)
regularizer=bias_regularizer,
trainable=trainable)
y = nn.bias_add(y, b)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册