Commit 55d41fd6 authored by Hongkun Yu, committed by A. Unique TensorFlower

Use distribution utils in XLNET

PiperOrigin-RevId: 331015243
Parent a8518117
......@@ -14,11 +14,6 @@
# ==============================================================================
"""XLNet classification finetuning runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools
# Import libraries
from absl import app
......@@ -34,7 +29,7 @@ from official.nlp.xlnet import optimization
from official.nlp.xlnet import training_utils
from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling
from official.utils.misc import tpu_lib
from official.utils.misc import distribution_utils
flags.DEFINE_integer("n_class", default=2, help="Number of classes.")
flags.DEFINE_string(
......@@ -135,14 +130,9 @@ def get_metric_fn():
def main(unused_argv):
del unused_argv
if FLAGS.strategy_type == "mirror":
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == "tpu":
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
else:
raise ValueError("The distribution strategy type is not supported: %s" %
FLAGS.strategy_type)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.strategy_type,
tpu_address=FLAGS.tpu)
if strategy:
logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync)
......
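Across all three runners the change is the same: the hand-rolled branching over MirroredStrategy/TPUStrategy is replaced with the shared distribution_utils helper. A minimal sketch of the new call pattern, assuming the Model Garden's official.utils.misc.distribution_utils module (the flag values and the strategy.scope() usage below are illustrative, not taken from the commit):

from absl import logging
from official.utils.misc import distribution_utils

# Resolve a tf.distribute strategy from a string instead of branching on
# FLAGS.strategy_type by hand; tpu_address is only consulted for the TPU case.
strategy = distribution_utils.get_distribution_strategy(
    distribution_strategy="mirrored",  # e.g. "mirrored" or "tpu" (example value)
    tpu_address=None)                  # TPU name/address when using "tpu"

with strategy.scope():
  pass  # build the XLNet model, optimizer, and metrics under this scope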
......@@ -12,12 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet classification finetuning runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
"""XLNet pretraining runner in tf2.0."""
import functools
import os
......@@ -34,7 +29,7 @@ from official.nlp.xlnet import optimization
from official.nlp.xlnet import training_utils
from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling
from official.utils.misc import tpu_lib
from official.utils.misc import distribution_utils
flags.DEFINE_integer(
"num_predict",
......@@ -77,17 +72,11 @@ def get_pretrainxlnet_model(model_config, run_config):
def main(unused_argv):
del unused_argv
num_hosts = 1
if FLAGS.strategy_type == "mirror":
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == "tpu":
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
topology = FLAGS.tpu_topology.split("x")
total_num_core = 2 * int(topology[0]) * int(topology[1])
num_hosts = total_num_core // FLAGS.num_core_per_host
else:
raise ValueError("The distribution strategy type is not supported: %s" %
FLAGS.strategy_type)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.strategy_type,
tpu_address=FLAGS.tpu)
if FLAGS.strategy_type == "tpu":
num_hosts = strategy.extended.num_hosts
if strategy:
logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync)
......
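In the pretraining runner the host count was previously derived from the TPU topology flags; with the helper in place it is read off the constructed TPUStrategy instead. A small sketch of the two computations (the "4x4" topology and the 8 cores per host are hypothetical example values, not from the commit):

# Before: parse a "<rows>x<cols>" topology string and divide the total core
# count (2 cores per chip) by the configured cores per host.
topology = "4x4".split("x")                    # stands in for FLAGS.tpu_topology
total_num_core = 2 * int(topology[0]) * int(topology[1])
num_hosts_before = total_num_core // 8         # stands in for FLAGS.num_core_per_host

# After: ask the strategy itself once it has been built for the TPU case.
# num_hosts_after = strategy.extended.num_hosts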
......@@ -14,11 +14,6 @@
# ==============================================================================
"""XLNet SQUAD finetuning runner in tf2.0."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import functools
import json
import os
......@@ -39,7 +34,7 @@ from official.nlp.xlnet import squad_utils
from official.nlp.xlnet import training_utils
from official.nlp.xlnet import xlnet_config
from official.nlp.xlnet import xlnet_modeling as modeling
from official.utils.misc import tpu_lib
from official.utils.misc import distribution_utils
flags.DEFINE_string(
"test_feature_path", default=None, help="Path to feature of test set.")
......@@ -217,14 +212,9 @@ def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
def main(unused_argv):
del unused_argv
if FLAGS.strategy_type == "mirror":
strategy = tf.distribute.MirroredStrategy()
elif FLAGS.strategy_type == "tpu":
cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
else:
raise ValueError("The distribution strategy type is not supported: %s" %
FLAGS.strategy_type)
strategy = distribution_utils.get_distribution_strategy(
distribution_strategy=FLAGS.strategy_type,
tpu_address=FLAGS.tpu)
if strategy:
logging.info("***** Number of cores used : %d",
strategy.num_replicas_in_sync)
......
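The `if strategy:` guard that follows each call is kept because get_distribution_strategy can return None when distribution is disabled (the Model Garden helper accepts an "off" value for this). A tiny sketch, assuming that behavior:

from absl import logging
from official.utils.misc import distribution_utils

strategy = distribution_utils.get_distribution_strategy(
    distribution_strategy="off")       # helper returns None in this case
if strategy:                           # only log when a real strategy exists
  logging.info("***** Number of cores used : %d",
               strategy.num_replicas_in_sync)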