提交 f926be0a 编写于 作者: A A. Unique TensorFlower

Moved set_gpu_thread_mode_and_count from vision/image_classification/common.py to

utils/misc/keras_utils

PiperOrigin-RevId: 283885102
上级 91a1ce9b
......@@ -18,8 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
import os
import time
from absl import logging
import tensorflow as tf
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import tf2
......@@ -79,18 +82,18 @@ class TimeHistory(tf.keras.callbacks.Callback):
elapsed_time = timestamp - self.start_time
examples_per_second = (self.batch_size * self.log_steps) / elapsed_time
self.timestamp_log.append(BatchTimestamp(self.global_steps, timestamp))
tf.compat.v1.logging.info(
logging.info(
"BenchmarkMetric: {'global step':%d, 'time_taken': %f,"
"'examples_per_second': %f}" %
(self.global_steps, elapsed_time, examples_per_second))
"'examples_per_second': %f}",
self.global_steps, elapsed_time, examples_per_second)
self.start_time = timestamp
def on_epoch_end(self, epoch, logs=None):
epoch_run_time = time.time() - self.epoch_start
self.epoch_runtime_log.append(epoch_run_time)
tf.compat.v1.logging.info(
"BenchmarkMetric: {'epoch':%d, 'time_taken': %f}" %
(epoch, epoch_run_time))
logging.info(
"BenchmarkMetric: {'epoch':%d, 'time_taken': %f}",
epoch, epoch_run_time)
def get_profiler_callback(model_dir, profile_steps, enable_tensorboard,
......@@ -110,7 +113,7 @@ def get_profiler_callback(model_dir, profile_steps, enable_tensorboard,
if start_step < 0 or start_step > stop_step:
raise ValueError(profile_steps_error_message)
if enable_tensorboard:
tf.compat.v1.logging.warn(
logging.warning(
'Both TensorBoard and profiler callbacks are used. Note that the '
'TensorBoard callback profiles the 2nd step (unless otherwise '
'specified). Please make sure the steps profiled by the two callbacks '
......@@ -143,14 +146,14 @@ class ProfilerCallback(tf.keras.callbacks.Callback):
if batch == self.start_step_in_epoch and self.should_start:
self.should_start = False
profiler.start()
tf.compat.v1.logging.info('Profiler started at Step %s', self.start_step)
logging.info('Profiler started at Step %s', self.start_step)
def on_batch_end(self, batch, logs=None):
if batch == self.stop_step_in_epoch and self.should_stop:
self.should_stop = False
results = profiler.stop()
profiler.save(self.log_dir, results)
tf.compat.v1.logging.info(
logging.info(
'Profiler saved profiles for steps between %s and %s to %s',
self.start_step, self.stop_step, self.log_dir)
......@@ -197,3 +200,31 @@ def set_config_v2(enable_xla=False):
def is_v2_0():
"""Returns true if using tf 2.0."""
return tf2.enabled()
def set_gpu_thread_mode_and_count(gpu_thread_mode,
datasets_num_private_threads,
num_gpus, per_gpu_thread_count):
"""Set GPU thread mode and count, and adjust dataset threads count."""
cpu_count = multiprocessing.cpu_count()
logging.info('Logical CPU cores: %s', cpu_count)
# Allocate private thread pool for each GPU to schedule and launch kernels
per_gpu_thread_count = per_gpu_thread_count or 2
os.environ['TF_GPU_THREAD_MODE'] = gpu_thread_mode
os.environ['TF_GPU_THREAD_COUNT'] = str(per_gpu_thread_count)
logging.info('TF_GPU_THREAD_COUNT: %s',
os.environ['TF_GPU_THREAD_COUNT'])
logging.info('TF_GPU_THREAD_MODE: %s',
os.environ['TF_GPU_THREAD_MODE'])
# Limit data preprocessing threadpool to CPU cores minus number of total GPU
# private threads and memory copy threads.
total_gpu_thread_count = per_gpu_thread_count * num_gpus
num_runtime_threads = num_gpus
if not datasets_num_private_threads:
datasets_num_private_threads = min(
cpu_count - total_gpu_thread_count - num_runtime_threads,
num_gpus * 8)
logging.info('Set datasets_num_private_threads to %s',
datasets_num_private_threads)
......@@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import multiprocessing
import os
from absl import flags
......@@ -174,32 +173,6 @@ class PiecewiseConstantDecayWithWarmup(
}
def set_gpu_thread_mode_and_count(flags_obj):
"""Set GPU thread mode and count, and adjust dataset threads count."""
cpu_count = multiprocessing.cpu_count()
tf.compat.v1.logging.info('Logical CPU cores: %s', cpu_count)
# Allocate private thread pool for each GPU to schedule and launch kernels
per_gpu_thread_count = flags_obj.per_gpu_thread_count or 2
os.environ['TF_GPU_THREAD_MODE'] = flags_obj.tf_gpu_thread_mode
os.environ['TF_GPU_THREAD_COUNT'] = str(per_gpu_thread_count)
tf.compat.v1.logging.info('TF_GPU_THREAD_COUNT: %s',
os.environ['TF_GPU_THREAD_COUNT'])
tf.compat.v1.logging.info('TF_GPU_THREAD_MODE: %s',
os.environ['TF_GPU_THREAD_MODE'])
# Limit data preprocessing threadpool to CPU cores minus number of total GPU
# private threads and memory copy threads.
total_gpu_thread_count = per_gpu_thread_count * flags_obj.num_gpus
num_runtime_threads = flags_obj.num_gpus
if not flags_obj.datasets_num_private_threads:
flags_obj.datasets_num_private_threads = min(
cpu_count - total_gpu_thread_count - num_runtime_threads,
flags_obj.num_gpus * 8)
tf.compat.v1.logging.info('Set datasets_num_private_threads to %s',
flags_obj.datasets_num_private_threads)
def get_optimizer(learning_rate=0.1):
"""Returns optimizer to use."""
# The learning_rate is overwritten at the beginning of each step by callback.
......
......@@ -83,7 +83,11 @@ def run(flags_obj):
# Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode:
common.set_gpu_thread_mode_and_count(flags_obj)
keras_utils.set_gpu_thread_mode_and_count(
per_gpu_thread_count=flags_obj.per_gpu_thread_count,
gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
num_gpus=flags_obj.num_gpus,
datasets_num_private_threads=flags_obj.datasets_num_private_threads)
common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj)
......
......@@ -54,7 +54,11 @@ def run(flags_obj):
# Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode:
common.set_gpu_thread_mode_and_count(flags_obj)
keras_utils.set_gpu_thread_mode_and_count(
per_gpu_thread_count=flags_obj.per_gpu_thread_count,
gpu_thread_mode=flags_obj.tf_gpu_thread_mode,
num_gpus=flags_obj.num_gpus,
datasets_num_private_threads=flags_obj.datasets_num_private_threads)
common.set_cudnn_batchnorm_mode()
dtype = flags_core.get_tf_dtype(flags_obj)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册