Unverified commit 15db2195, authored by S saberkun, committed by GitHub

Merged commit includes the following changes: (#6926)

250713045  by hongkuny<hongkuny@google.com>:

    TPU util

--

PiperOrigin-RevId: 250713045
Parent d76e39e7
@@ -27,12 +27,14 @@ from absl import flags
 from absl import logging
 import tensorflow as tf
 # Import BERT model libraries.
 from official.bert import bert_models
 from official.bert import input_pipeline
 from official.bert import model_saving_utils
 from official.bert import model_training_utils
 from official.bert import modeling
 from official.bert import optimization
+from official.bert import tpu_lib
 flags.DEFINE_enum(
     'mode', 'train_and_eval', ['train_and_eval', 'export_only'],
@@ -59,9 +61,7 @@ flags.DEFINE_string(
     'Path to the directory, where trainined model will be '
     'exported.')
 flags.DEFINE_enum(
-    'strategy_type',
-    'mirror',
-    ['tpu', 'mirror'],
+    'strategy_type', 'mirror', ['tpu', 'mirror'],
     'Distribution Strategy type to use for training. `tpu` uses '
     'TPUStrategy for running on TPUs, `mirror` uses GPUs with '
     'single host.')
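This hunk only reflows the flag definition, but it is a good place for a quick illustration: absl's DEFINE_enum validates the flag at parse time, so unlisted values never reach main(). A standalone sketch (a hypothetical demo script, not part of this commit):

# enum_flag_demo.py - hypothetical demo, not part of this commit.
from absl import app
from absl import flags

FLAGS = flags.FLAGS

flags.DEFINE_enum(
    'strategy_type', 'mirror', ['tpu', 'mirror'],
    'Distribution Strategy type to use for training.')


def main(_):
  # Only 'tpu' or 'mirror' reach this point; anything else fails parsing.
  print('strategy_type =', FLAGS.strategy_type)


if __name__ == '__main__':
  app.run(main)  # e.g. python enum_flag_demo.py --strategy_type=tpu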
@@ -227,6 +227,7 @@ def run_bert(strategy, input_meta_data):
 def main(_):
   # Users should always run this script under TF 2.x
   assert tf.version.VERSION.startswith('2.')
   with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
     input_meta_data = json.loads(reader.read().decode('utf-8'))
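As an aside, the metadata read above goes through tf.io.gfile, so the same code handles local and remote (e.g. gs://) paths alike. A small self-contained sketch with a placeholder path:

import json

import tensorflow as tf

# Placeholder path; tf.io.gfile also accepts e.g. 'gs://bucket/meta.json'.
with tf.io.gfile.GFile('/tmp/input_meta_data.json', 'rb') as reader:
  input_meta_data = json.loads(reader.read().decode('utf-8'))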
@@ -237,14 +238,13 @@ def main(_):
   if FLAGS.strategy_type == 'mirror':
     strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == 'tpu':
-    logging.info('Use TPU at %s', FLAGS.tpu if FLAGS.tpu is not None else '')
-    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
-        tpu=FLAGS.tpu)
-    tf.config.experimental_connect_to_host(cluster_resolver.master())  # pylint: disable=line-too-long
-    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
+    # Initialize TPU System.
+    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(
         cluster_resolver, steps_per_run=FLAGS.steps_per_run)
   else:
     raise ValueError('The distribution strategy type is not supported: %s' %
                      FLAGS.strategy_type)
   run_bert(strategy, input_meta_data)
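The net effect of the hunk above: four inline TPU-setup calls collapse into one tpu_lib.tpu_initialize call. A sketch of the replaced sequence, using only calls that appear in this diff:

import tensorflow as tf


def _inline_tpu_init(tpu_address):
  """What run_classifier previously inlined in main()."""
  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu=tpu_address)
  tf.config.experimental_connect_to_host(cluster_resolver.master())
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  return cluster_resolver


# After this commit, the same effect is a single helper call:
#   cluster_resolver = tpu_lib.tpu_initialize(tpu_address)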
@@ -19,16 +19,19 @@ from __future__ import division
 from __future__ import print_function
 import functools
 from absl import app
 from absl import flags
 from absl import logging
 import tensorflow as tf
 # Import BERT model libraries.
 from official.bert import bert_models
 from official.bert import input_pipeline
 from official.bert import model_training_utils
 from official.bert import modeling
 from official.bert import optimization
+from official.bert import tpu_lib
 flags.DEFINE_string('input_files', None,
                     'File path to retrieve training data for pre-training.')
@@ -40,9 +43,7 @@ flags.DEFINE_string(
     'are stored. If not specified, save to /tmp/bert20/.'))
 flags.DEFINE_string('tpu', '', 'TPU address to connect to.')
 flags.DEFINE_enum(
-    'strategy_type',
-    'mirror',
-    ['tpu', 'mirror'],
+    'strategy_type', 'mirror', ['tpu', 'mirror'],
     'Distribution Strategy type to use for training. `tpu` uses '
     'TPUStrategy for running on TPUs, `mirror` uses GPUs with '
     'single host.')
@@ -157,21 +158,20 @@ def run_bert_pretrain(strategy):
 def main(_):
   # Users should always run this script under TF 2.x
   assert tf.version.VERSION.startswith('2.')
   if not FLAGS.model_dir:
     FLAGS.model_dir = '/tmp/bert20/'
   strategy = None
   if FLAGS.strategy_type == 'tpu':
-    logging.info('Use TPU at %s',
-                 FLAGS.tpu if FLAGS.tpu is not None else '')
-    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
-        tpu=FLAGS.tpu)
-    tf.config.experimental_connect_to_host(cluster_resolver.master())  # pylint: disable=line-too-long
-    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
-  if FLAGS.strategy_type == 'mirror':
-    strategy = tf.distribute.MirroredStrategy()
-  elif FLAGS.strategy_type == 'tpu':
+    # Initialize TPU System.
+    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(
         cluster_resolver, steps_per_run=FLAGS.steps_per_run)
+  elif FLAGS.strategy_type == 'mirror':
+    strategy = tf.distribute.MirroredStrategy()
   else:
     raise ValueError('The distribution strategy type is not supported: %s' %
                      FLAGS.strategy_type)
   if strategy:
     print('***** Number of cores used : ', strategy.num_replicas_in_sync)
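The printed num_replicas_in_sync is the number of replicas the strategy updates in lockstep. A tiny sketch, assuming a single-host machine with zero or more GPUs:

import tensorflow as tf

# On a single host, MirroredStrategy mirrors across all visible GPUs
# (or falls back to a single CPU replica when no GPU is present).
strategy = tf.distribute.MirroredStrategy()
print('***** Number of cores used : ', strategy.num_replicas_in_sync)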
@@ -27,6 +27,7 @@ from absl import flags
 from absl import logging
 import tensorflow as tf
 # Import BERT model libraries.
 from official.bert import bert_models
 from official.bert import input_pipeline
 from official.bert import model_training_utils
@@ -34,6 +35,7 @@ from official.bert import modeling
 from official.bert import optimization
 from official.bert import squad_lib
 from official.bert import tokenization
+from official.bert import tpu_lib
 flags.DEFINE_bool('do_train', False, 'Whether to run training.')
 flags.DEFINE_bool('do_predict', False, 'Whether to run eval on the dev set.')
@@ -54,9 +56,7 @@ flags.DEFINE_string(
     'init_checkpoint', None,
     'Initial checkpoint (usually from a pre-trained BERT model).')
 flags.DEFINE_enum(
-    'strategy_type',
-    'mirror',
-    ['tpu', 'mirror'],
+    'strategy_type', 'mirror', ['tpu', 'mirror'],
     'Distribution Strategy type to use for training. `tpu` uses '
     'TPUStrategy for running on TPUs, `mirror` uses GPUs with '
     'single host.')
@@ -306,23 +306,18 @@ def predict_squad(strategy, input_meta_data):
 def main(_):
   # Users should always run this script under TF 2.x
   assert tf.version.VERSION.startswith('2.')
   with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
     input_meta_data = json.loads(reader.read().decode('utf-8'))
   strategy = None
   if FLAGS.strategy_type == 'tpu':
-    logging.info('Use TPU at %s',
-                 FLAGS.tpu if FLAGS.tpu is not None else '')
-    cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
-        tpu=FLAGS.tpu)
-    tf.config.experimental_connect_to_host(cluster_resolver.master())  # pylint: disable=line-too-long
-    tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
-  if FLAGS.strategy_type == 'mirror':
-    strategy = tf.distribute.MirroredStrategy()
-  elif FLAGS.strategy_type == 'tpu':
+    # Initialize TPU System.
+    cluster_resolver = tpu_lib.tpu_initialize(FLAGS.tpu)
     strategy = tf.distribute.experimental.TPUStrategy(
         cluster_resolver, steps_per_run=FLAGS.steps_per_run)
+  elif FLAGS.strategy_type == 'mirror':
+    strategy = tf.distribute.MirroredStrategy()
   elif FLAGS.strategy_type == 'multi_worker_mirror':
     strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
   else:
     raise ValueError('The distribution strategy type is not supported: %s' %
                      FLAGS.strategy_type)
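Note that this script also checks for a 'multi_worker_mirror' strategy, yet its strategy_type enum above only permits 'tpu' and 'mirror', so that branch is unreachable through the flag as defined. For reference, MultiWorkerMirroredStrategy discovers its peers via the TF_CONFIG environment variable; a sketch with placeholder worker addresses:

import json
import os

import tensorflow as tf

# Placeholder cluster spec; real jobs set TF_CONFIG per worker.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'worker': ['host1:12345', 'host2:12345']},
    'task': {'type': 'worker', 'index': 0},
})
strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()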
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Initializes TPU system for TF 2.0."""

import tensorflow as tf


def tpu_initialize(tpu_address):
  """Initializes TPU for TF 2.0 training.

  Args:
    tpu_address: string, bns address of TPU workers.

  Returns:
    A TPUClusterResolver.
  """
  cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
      tpu=tpu_address)
  tf.config.experimental_connect_to_host(cluster_resolver.master())
  tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
  return cluster_resolver
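A minimal usage sketch for the new helper, mirroring the call sites updated above (the TPU address and steps_per_run value are placeholders):

import tensorflow as tf

from official.bert import tpu_lib

# Connects to the TPU workers, initializes the TPU system, and returns
# the resolver that TPUStrategy consumes.
cluster_resolver = tpu_lib.tpu_initialize('grpc://10.0.0.2:8470')
strategy = tf.distribute.experimental.TPUStrategy(
    cluster_resolver, steps_per_run=100)
print('Number of TPU cores:', strategy.num_replicas_in_sync)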