提交 722d9e57 编写于 作者: H Hongkun Yu 提交者: A. Unique TensorFlower

Clearly demarcate contrib symbols from standard tf symbols by importing them directly.

PiperOrigin-RevId: 285618209
上级 e5c71d51
......@@ -115,8 +115,7 @@ def neumf_model_fn(features, labels, mode, params):
beta2=params["beta2"],
epsilon=params["epsilon"])
if params["use_tpu"]:
# TODO(seemuch): remove this contrib import
optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)
mlperf_helper.ncf_print(key=mlperf_helper.TAGS.MODEL_HP_LOSS_FN,
value=mlperf_helper.TAGS.BCE)
......@@ -274,7 +273,7 @@ def _get_estimator_spec_with_metrics(logits, # type: tf.Tensor
use_tpu_spec)
if use_tpu_spec:
return tf.contrib.tpu.TPUEstimatorSpec(
return tf.estimator.tpu.TPUEstimatorSpec(
mode=tf.estimator.ModeKeys.EVAL,
loss=cross_entropy,
eval_metrics=(metric_fn, [in_top_k, ndcg, metric_weights]))
......
......@@ -283,14 +283,6 @@ def set_up_synthetic_data():
_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
# TODO(tobyboyd): Remove when contrib.distribute is all in core.
if hasattr(tf, 'contrib'):
_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
_monkey_patch_dataset_method(
tf.contrib.distribute.CollectiveAllReduceStrategy)
else:
print('Contrib missing: Skip monkey patch tf.contrib.distribute.*')
def undo_set_up_synthetic_data():
......@@ -298,14 +290,6 @@ def undo_set_up_synthetic_data():
_undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
_undo_monkey_patch_dataset_method(
tf.distribute.experimental.MultiWorkerMirroredStrategy)
# TODO(tobyboyd): Remove when contrib.distribute is all in core.
if hasattr(tf, 'contrib'):
_undo_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
_undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
_undo_monkey_patch_dataset_method(
tf.contrib.distribute.CollectiveAllReduceStrategy)
else:
print('Contrib missing: Skip remove monkey patch tf.contrib.distribute.*')
def configure_cluster(worker_hosts=None, task_index=-1):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册