Commit 20101930 authored by R Ruomei Yan

Address review comments from mid June

Parent 7dfef01d
......@@ -905,7 +905,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.enable_eager = True
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
FLAGS.batch_size = 128 * 8 # 8 GPUs
FLAGS.batch_size = 128 * 8
self._run_and_report_benchmark()
def benchmark_8_gpu_amp(self):
......@@ -918,7 +918,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.batch_size = 256 * 8
self._run_and_report_benchmark()
def benchmark_8_gpu_tweaked(self):
......@@ -942,7 +942,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.enable_xla = True
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu')
FLAGS.batch_size = 128 * 8 # 8 GPUs
FLAGS.batch_size = 128 * 8
self._run_and_report_benchmark()
def benchmark_xla_8_gpu_amp(self):
......@@ -956,7 +956,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.enable_xla = True
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.batch_size = 256 * 8
self._run_and_report_benchmark()
def benchmark_xla_8_gpu_tweaked(self):
......@@ -982,7 +982,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.enable_eager = True
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.batch_size = 256 * 8
self._run_and_report_benchmark()
def benchmark_8_gpu_fp16_tweaked(self):
......@@ -994,7 +994,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.enable_eager = True
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.batch_size = 256 * 8
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 40
self._run_and_report_benchmark()
......@@ -1009,7 +1009,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir(
'benchmark_8_gpu_fp16_dynamic_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.batch_size = 256 * 8
FLAGS.loss_scale = 'dynamic'
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 40
......@@ -1025,7 +1025,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.enable_xla = True
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.batch_size = 256 * 8
self._run_and_report_benchmark()
def benchmark_xla_8_gpu_fp16_tweaked(self):
......@@ -1038,7 +1038,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.enable_xla = True
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.batch_size = 256 * 8
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 48
self._run_and_report_benchmark()
......@@ -1074,7 +1074,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir(
'benchmark_xla_8_gpu_fp16_dynamic_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.batch_size = 256 * 8
FLAGS.loss_scale = 'dynamic'
FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 48
......@@ -1871,7 +1871,6 @@ class KerasClusteringBenchmarkRealBase(Resnet50KerasBenchmarkBase):
'report_accuracy_metrics': False,
'data_dir': os.path.join(root_data_dir, 'imagenet'),
'clustering_method': 'selective_clustering',
'number_of_clusters': 256,
'train_steps': 110,
'log_steps': 10,
})
......
......@@ -26,7 +26,7 @@ from absl import flags
from absl import logging
import tensorflow as tf
from tensorflow_model_optimization.python.core.clustering.keras import cluster
import tensorflow_model_optimization as tfmot
from official.modeling import performance
from official.utils.flags import core as flags_core
......@@ -39,56 +39,38 @@ from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model
def selective_layers_to_cluster(model):
last_3conv2d_layers_to_cluster = [
layer.name
for layer in model.layers
def cluster_last_three_conv2d_layers(model):
last_three_conv2d_layers = [
layer for layer in model.layers
if isinstance(layer, tf.keras.layers.Conv2D) and
not isinstance(layer, tf.keras.layers.DepthwiseConv2D)
]
last_3conv2d_layers_to_cluster = last_3conv2d_layers_to_cluster[-3:]
return last_3conv2d_layers_to_cluster
def selective_clustering_clone_wrapper(clustering_params1, clustering_params2,
model):
def apply_clustering_to_conv2d_but_depthwise(layer):
layers_list = selective_layers_to_cluster(model)
if layer.name in layers_list:
if layer.name != layers_list[-1]:
print("Wrapped layer " + layer.name +
" with " +
str(clustering_params1["number_of_clusters"]) + " clusters.")
return cluster.cluster_weights(layer, **clustering_params1)
else:
print("Wrapped layer " + layer.name +
" with " +
str(clustering_params2["number_of_clusters"]) + " clusters.")
return cluster.cluster_weights(layer, **clustering_params2)
return layer
return apply_clustering_to_conv2d_but_depthwise
def cluster_model_selectively(model, selective_layers_to_cluster,
clustering_params1, clustering_params2):
result_layer_model = tf.keras.models.clone_model(
model,
clone_function=selective_clustering_clone_wrapper(clustering_params1,
clustering_params2,
model),
)
return result_layer_model
]
last_three_conv2d_layers = last_three_conv2d_layers[-3:]
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
clustering_params1 = {
'number_of_clusters': 256,
'cluster_centroids_init': CentroidInitialization.LINEAR
}
clustering_params2 = {
'number_of_clusters': 32,
'cluster_centroids_init': CentroidInitialization.LINEAR
}
def cluster_fn(layer):
if layer not in last_three_conv2d_layers:
return layer
if layer == last_three_conv2d_layers[0] or layer == last_three_conv2d_layers[1]:
clustered = cluster_weights(layer, **clustering_params1)
print("Clustered {} with {} clusters".format(layer.name, clustering_params1['number_of_clusters']))
else:
clustered = cluster_weights(layer, **clustering_params2)
print("Clustered {} with {} clusters".format(layer.name, clustering_params2['number_of_clusters']))
return clustered
def get_selectively_clustered_model(model, clustering_params1,
clustering_params2):
clustered_model = cluster_model_selectively(model,
selective_layers_to_cluster,
clustering_params1,
clustering_params2)
return clustered_model
return tf.keras.models.clone_model(model, clone_function=cluster_fn)
def run(flags_obj):
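Reviewer note: a minimal, self-contained sketch (not part of this diff) of how the new `cluster_last_three_conv2d_layers` helper is expected to behave. The toy model is made up, and the import path is an assumption based on the test file's `from official.benchmark.models import resnet_imagenet_main`.

```python
# Hypothetical usage sketch; the toy model and import path are assumptions.
import tensorflow as tf
from official.benchmark.models import resnet_imagenet_main

toy_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(8, 3, activation='relu', input_shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(16, 3, activation='relu'),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.Conv2D(64, 3, activation='relu'),
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(10),
])

# Only the last three non-depthwise Conv2D layers get wrapped: the first two
# of those with 256 clusters, the final one with 32 clusters; every other
# layer is returned unchanged by cluster_fn.
clustered = resnet_imagenet_main.cluster_last_three_conv2d_layers(toy_model)
print([type(layer).__name__ for layer in clustered.layers])
# Expected: the first Conv2D stays 'Conv2D'; the last three appear as
# 'ClusterWeights' wrappers.
```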
......@@ -244,12 +226,8 @@ def run(flags_obj):
layers=tf.keras.layers)
elif flags_obj.model == 'mobilenet_pretrained':
model = tf.keras.applications.mobilenet.MobileNet(
alpha=1.0,
depth_multiplier=1,
dropout=1e-7,
include_top=True,
weights='imagenet',
pooling=None,
classes=1000,
layers=tf.keras.layers)
......@@ -277,16 +255,7 @@ def run(flags_obj):
if dtype != tf.float32 or flags_obj.fp16_implementation == 'graph_rewrite':
raise NotImplementedError(
'Clustering is currently only supported on dtype=tf.float32.')
clustering_params1 = {
'number_of_clusters': flags_obj.number_of_clusters,
'cluster_centroids_init': 'linear'
}
clustering_params2 = {
'number_of_clusters': 32,
'cluster_centroids_init': 'linear'
}
model = get_selectively_clustered_model(model, clustering_params1,
clustering_params2)
model = cluster_last_three_conv2d_layers(model)
elif flags_obj.clustering_method:
raise NotImplementedError(
'Only selective_clustering is implemented.')
......@@ -324,7 +293,6 @@ def run(flags_obj):
num_eval_steps = None
validation_data = None
# if not strategy and flags_obj.explicit_gpu_placement:
if not strategy and flags_obj.explicit_gpu_placement:
# TODO(b/135607227): Add device scope automatically in Keras training loop
# when not using distribution strategy.
......@@ -350,7 +318,7 @@ def run(flags_obj):
model = tfmot.sparsity.keras.strip_pruning(model)
if flags_obj.clustering_method:
model = cluster.strip_clustering(model)
model = tfmot.clustering.keras.strip_clustering(model)
if flags_obj.enable_checkpoint_and_export:
if dtype == tf.bfloat16:
......@@ -363,14 +331,6 @@ def run(flags_obj):
if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__()
if flags_obj.clustering_method:
if flags_obj.save_files_to:
keras_file = os.path.join(flags_obj.save_files_to, 'clustered.h5')
else:
keras_file = './clustered.h5'
print('Saving clustered and stripped model to: ', keras_file)
tf.keras.models.save_model(model, keras_file)
stats = common.build_stats(history, eval_output, callbacks)
return stats
......
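Reviewer note: a small self-contained sketch (toy model, no training, assumed setup) of the `tfmot.clustering.keras.strip_clustering` call that replaces `cluster.strip_clustering` above:

```python
# Standalone sketch of cluster_weights + strip_clustering; the toy model and
# single 32-cluster setting are assumptions for brevity.
import tensorflow as tf
import tensorflow_model_optimization as tfmot

base = tf.keras.Sequential([
    tf.keras.layers.Conv2D(16, 3, input_shape=(8, 8, 3)),
    tf.keras.layers.Conv2D(8, 3),
])

# Wrap the layers with ClusterWeights (32 clusters, linear centroid init).
clustered = tfmot.clustering.keras.cluster_weights(
    base,
    number_of_clusters=32,
    cluster_centroids_init=tfmot.clustering.keras.CentroidInitialization.LINEAR)

# strip_clustering removes the wrappers and keeps the clustered kernels,
# which is what run() does before checkpointing/exporting the model.
stripped = tfmot.clustering.keras.strip_clustering(clustered)
print([type(layer).__name__ for layer in stripped.layers])  # plain Conv2D again
```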
......@@ -26,7 +26,7 @@ from official.benchmark.models import resnet_imagenet_main
from official.utils.testing import integration
from official.vision.image_classification.resnet import imagenet_preprocessing
# TBC: joint clustering and tuning is not supported yet, so only one of these flags should be selected.
@parameterized.parameters(
"resnet",
# "resnet_polynomial_decay", b/151854314
......
......@@ -9,7 +9,7 @@ psutil>=5.4.3
py-cpuinfo>=3.3.0
scipy>=0.19.1
tensorflow-hub>=0.6.0
tensorflow-model-optimization>=0.2.1
tensorflow-model-optimization>=0.4.1
tensorflow-datasets
tensorflow-addons
dataclasses
......
......@@ -355,11 +355,8 @@ def define_pruning_flags():
def define_clustering_flags():
"""Define flags for clustering methods."""
flags.DEFINE_string('clustering_method', None,
'None (no clustering) or selective_clustering.')
flags.DEFINE_integer('number_of_clusters', 256,
'Number of clusters used in each layer.')
flags.DEFINE_string('save_files_to', None,
'The path to save Keras models and tflite models.')
'None (no clustering) or selective_clustering '
'(cluster the last three Conv2D layers of the model).')
def get_synth_input_fn(height,
width,
......
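Reviewer note: if useful, a hypothetical standalone script (not part of this change) showing how the single remaining `clustering_method` flag is parsed and gated, mirroring the `run()` logic above:

```python
# Hypothetical script; the flag definition mirrors define_clustering_flags()
# and the gating mirrors run(), but the script itself is illustrative only.
from absl import app
from absl import flags

flags.DEFINE_string(
    'clustering_method', None,
    'None (no clustering) or selective_clustering '
    '(cluster the last three Conv2D layers of the model).')

FLAGS = flags.FLAGS


def main(_):
  if FLAGS.clustering_method == 'selective_clustering':
    print('Clustering the last three Conv2D layers.')
  elif FLAGS.clustering_method:
    raise NotImplementedError('Only selective_clustering is implemented.')
  else:
    print('No clustering.')


if __name__ == '__main__':
  app.run(main)
```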