From 722d9e57490553c9ffab8d3b5a4e40d8ae1ac969 Mon Sep 17 00:00:00 2001
From: Hongkun Yu
Date: Sat, 14 Dec 2019 23:37:46 -0800
Subject: [PATCH] Clearly demarcate contrib symbols from standard tf symbols by
 importing them directly.

PiperOrigin-RevId: 285618209
---
 official/recommendation/neumf_model.py    |  5 ++---
 official/utils/misc/distribution_utils.py | 16 ----------------
 2 files changed, 2 insertions(+), 19 deletions(-)

diff --git a/official/recommendation/neumf_model.py b/official/recommendation/neumf_model.py
index 5d3c82c66..bb3c6f26e 100644
--- a/official/recommendation/neumf_model.py
+++ b/official/recommendation/neumf_model.py
@@ -115,8 +115,7 @@ def neumf_model_fn(features, labels, mode, params):
         beta2=params["beta2"],
         epsilon=params["epsilon"])
     if params["use_tpu"]:
-      # TODO(seemuch): remove this contrib import
-      optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)
+      optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)
 
     mlperf_helper.ncf_print(key=mlperf_helper.TAGS.MODEL_HP_LOSS_FN,
                             value=mlperf_helper.TAGS.BCE)
@@ -274,7 +273,7 @@ def _get_estimator_spec_with_metrics(logits,  # type: tf.Tensor
                                        use_tpu_spec)
 
   if use_tpu_spec:
-    return tf.contrib.tpu.TPUEstimatorSpec(
+    return tf.estimator.tpu.TPUEstimatorSpec(
         mode=tf.estimator.ModeKeys.EVAL,
         loss=cross_entropy,
         eval_metrics=(metric_fn, [in_top_k, ndcg, metric_weights]))
diff --git a/official/utils/misc/distribution_utils.py b/official/utils/misc/distribution_utils.py
index 82c38e105..610346c9a 100644
--- a/official/utils/misc/distribution_utils.py
+++ b/official/utils/misc/distribution_utils.py
@@ -283,14 +283,6 @@ def set_up_synthetic_data():
   _monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
   _monkey_patch_dataset_method(
       tf.distribute.experimental.MultiWorkerMirroredStrategy)
-  # TODO(tobyboyd): Remove when contrib.distribute is all in core.
-  if hasattr(tf, 'contrib'):
-    _monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
-    _monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
-    _monkey_patch_dataset_method(
-        tf.contrib.distribute.CollectiveAllReduceStrategy)
-  else:
-    print('Contrib missing: Skip monkey patch tf.contrib.distribute.*')
 
 
 def undo_set_up_synthetic_data():
@@ -298,14 +290,6 @@ def undo_set_up_synthetic_data():
   _undo_monkey_patch_dataset_method(tf.distribute.MirroredStrategy)
   _undo_monkey_patch_dataset_method(
       tf.distribute.experimental.MultiWorkerMirroredStrategy)
-  # TODO(tobyboyd): Remove when contrib.distribute is all in core.
-  if hasattr(tf, 'contrib'):
-    _undo_monkey_patch_dataset_method(tf.contrib.distribute.MirroredStrategy)
-    _undo_monkey_patch_dataset_method(tf.contrib.distribute.OneDeviceStrategy)
-    _undo_monkey_patch_dataset_method(
-        tf.contrib.distribute.CollectiveAllReduceStrategy)
-  else:
-    print('Contrib missing: Skip remove monkey patch tf.contrib.distribute.*')
 
 
 def configure_cluster(worker_hosts=None, task_index=-1):
-- 
GitLab
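
Note: the sketch below is a minimal, illustrative usage of the in-core replacement API that this patch switches to, assuming TensorFlow 1.15+/2.x where the tf.compat.v1 symbols are available. The hyperparameter values and the use_tpu flag are hypothetical stand-ins for the params dict read in neumf_model.py, not values taken from the repository.

import tensorflow as tf

# Build a standard v1 optimizer (illustrative hyperparameters, not the
# repository's actual defaults).
optimizer = tf.compat.v1.train.AdamOptimizer(
    learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8)

use_tpu = True  # hypothetical stand-in for params["use_tpu"]
if use_tpu:
  # tf.compat.v1.tpu.CrossShardOptimizer is the in-core symbol used here in
  # place of the removed tf.contrib.tpu.CrossShardOptimizer.
  optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)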