提交 c0bdb378 编写于 作者: A A. Unique TensorFlower

Using better version of l2 loss to avoid reshape, concat and split ops.

Also adding support for CTL mode in the borg file.

PiperOrigin-RevId: 284075404
上级 f079ed2e
......@@ -274,13 +274,12 @@ def run(flags_obj):
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
if flags_obj.single_l2_loss_op:
filtered_variables = [
tf.reshape(v, (-1,))
l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.add_n([
tf.nn.l2_loss(v)
for v in trainable_variables
if 'bn' not in v.name
]
l2_loss = resnet_model.L2_WEIGHT_DECAY * 2 * tf.nn.l2_loss(
tf.concat(filtered_variables, axis=0))
])
loss += (l2_loss / num_replicas)
else:
loss += (tf.reduce_sum(model.losses) / num_replicas)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册