未验证 提交 7313d7a9 编写于 作者: Z Zihan Wang 提交者: GitHub

Merge branch 'tensorflow:master' into master

...@@ -186,7 +186,7 @@ ...@@ -186,7 +186,7 @@
"exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')\n", "exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')\n",
"tfds_name = 'cifar10'\n", "tfds_name = 'cifar10'\n",
"ds,ds_info = tfds.load(\n", "ds,ds_info = tfds.load(\n",
"tfds_name\n", "tfds_name,\n",
"with_info=True)\n", "with_info=True)\n",
"ds_info" "ds_info"
] ]
......
...@@ -71,6 +71,7 @@ class OrbitExperimentRunner: ...@@ -71,6 +71,7 @@ class OrbitExperimentRunner:
controller_cls=orbit.Controller, controller_cls=orbit.Controller,
summary_manager: Optional[orbit.utils.SummaryManager] = None, summary_manager: Optional[orbit.utils.SummaryManager] = None,
eval_summary_manager: Optional[orbit.utils.SummaryManager] = None, eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
enable_async_checkpointing: bool = False,
): ):
"""Constructor. """Constructor.
...@@ -94,6 +95,8 @@ class OrbitExperimentRunner: ...@@ -94,6 +95,8 @@ class OrbitExperimentRunner:
summary manager. summary manager.
eval_summary_manager: Instance of the eval summary manager to override eval_summary_manager: Instance of the eval summary manager to override
default eval summary manager. default eval summary manager.
enable_async_checkpointing: Optional boolean indicating whether to enable
async checkpoint saving.
""" """
self.strategy = distribution_strategy or tf.distribute.get_strategy() self.strategy = distribution_strategy or tf.distribute.get_strategy()
self._params = params self._params = params
...@@ -115,7 +118,8 @@ class OrbitExperimentRunner: ...@@ -115,7 +118,8 @@ class OrbitExperimentRunner:
save_summary=save_summary, save_summary=save_summary,
train_actions=train_actions, train_actions=train_actions,
eval_actions=eval_actions, eval_actions=eval_actions,
controller_cls=controller_cls) controller_cls=controller_cls,
enable_async_checkpointing=enable_async_checkpointing)
@property @property
def params(self) -> config_definitions.ExperimentConfig: def params(self) -> config_definitions.ExperimentConfig:
...@@ -188,13 +192,16 @@ class OrbitExperimentRunner: ...@@ -188,13 +192,16 @@ class OrbitExperimentRunner:
checkpoint_manager = None checkpoint_manager = None
return checkpoint_manager return checkpoint_manager
def _build_controller(self, def _build_controller(
trainer, self,
evaluator, trainer,
save_summary: bool = True, evaluator,
train_actions: Optional[List[orbit.Action]] = None, save_summary: bool = True,
eval_actions: Optional[List[orbit.Action]] = None, train_actions: Optional[List[orbit.Action]] = None,
controller_cls=orbit.Controller) -> orbit.Controller: eval_actions: Optional[List[orbit.Action]] = None,
controller_cls=orbit.Controller,
enable_async_checkpointing: bool = False,
) -> orbit.Controller:
"""Builds a Orbit controler.""" """Builds a Orbit controler."""
train_actions = [] if not train_actions else train_actions train_actions = [] if not train_actions else train_actions
if trainer: if trainer:
...@@ -223,6 +230,7 @@ class OrbitExperimentRunner: ...@@ -223,6 +230,7 @@ class OrbitExperimentRunner:
global_step=self.trainer.global_step, global_step=self.trainer.global_step,
steps_per_loop=self.params.trainer.steps_per_loop, steps_per_loop=self.params.trainer.steps_per_loop,
checkpoint_manager=self.checkpoint_manager, checkpoint_manager=self.checkpoint_manager,
enable_async_checkpointing=enable_async_checkpointing,
summary_dir=os.path.join(self.model_dir, 'train') summary_dir=os.path.join(self.model_dir, 'train')
if (save_summary) if (save_summary)
else None, else None,
...@@ -309,6 +317,7 @@ def run_experiment( ...@@ -309,6 +317,7 @@ def run_experiment(
controller_cls=orbit.Controller, controller_cls=orbit.Controller,
summary_manager: Optional[orbit.utils.SummaryManager] = None, summary_manager: Optional[orbit.utils.SummaryManager] = None,
eval_summary_manager: Optional[orbit.utils.SummaryManager] = None, eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
enable_async_checkpointing: bool = False,
) -> Tuple[tf.keras.Model, Mapping[str, Any]]: ) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params. """Runs train/eval configured by the experiment params.
...@@ -332,6 +341,8 @@ def run_experiment( ...@@ -332,6 +341,8 @@ def run_experiment(
manager. manager.
eval_summary_manager: Instance of the eval summary manager to override eval_summary_manager: Instance of the eval summary manager to override
default eval summary manager. default eval summary manager.
enable_async_checkpointing: Optional boolean indicating whether to enable
async checkpoint saving.
Returns: Returns:
A 2-tuple of (model, eval_logs). A 2-tuple of (model, eval_logs).
...@@ -353,5 +364,6 @@ def run_experiment( ...@@ -353,5 +364,6 @@ def run_experiment(
controller_cls=controller_cls, controller_cls=controller_cls,
summary_manager=summary_manager, summary_manager=summary_manager,
eval_summary_manager=eval_summary_manager, eval_summary_manager=eval_summary_manager,
enable_async_checkpointing=enable_async_checkpointing,
) )
return runner.run() return runner.run()
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
"""Training utils.""" """Training utils."""
import dataclasses import dataclasses
import inspect import inspect
import json import json
...@@ -22,10 +23,12 @@ from typing import Any, Callable, Dict, List, Optional, Union ...@@ -22,10 +23,12 @@ from typing import Any, Callable, Dict, List, Optional, Union
from absl import logging from absl import logging
import gin import gin
import numpy as np
import orbit import orbit
import tensorflow as tf import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import # pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import ops
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
# pylint: enable=g-direct-tensorflow-import # pylint: enable=g-direct-tensorflow-import
from official.core import base_task from official.core import base_task
...@@ -564,3 +567,44 @@ def try_count_flops(model: Union[tf.Module, tf.keras.Model], ...@@ -564,3 +567,44 @@ def try_count_flops(model: Union[tf.Module, tf.keras.Model],
'reached before this run.', e) 'reached before this run.', e)
return None return None
return None return None
@ops.RegisterStatistics('Einsum', 'flops')
def _einsum_flops(graph, node):
"""Calculates the compute resources needed for Einsum."""
assert len(node.input) == 2
x_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
graph, node.input[0])
y_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
graph, node.input[1])
x_shape.assert_is_fully_defined()
y_shape.assert_is_fully_defined()
x_shape = x_shape.as_list()
y_shape = y_shape.as_list()
equation = str(node.attr['equation'])
equation = (
equation.replace('s:', '')
.replace('"', '')
.replace(' ', '')
.replace('\n', '')
)
x_str = equation.split(',')[0]
y_r_str = equation.split(',')[1]
y_str = y_r_str.split('->')[0]
r_str = y_r_str.split('->')[1]
shape_dic = {}
contracted = set()
for indice in x_str + y_str:
if indice in x_str:
indice_dim = x_shape[x_str.find(indice)]
elif indice in y_str:
indice_dim = y_shape[y_str.find(indice)]
else:
raise ValueError('indice {} not found in inputs'.format(indice))
shape_dic[indice] = indice_dim
if indice not in r_str:
contracted.add(indice)
madds = np.prod([shape_dic[indice] for indice in r_str]) * (
np.prod([shape_dic[indice] for indice in contracted]))
flops = 2 * madds
return ops.OpStats('flops', flops)
...@@ -38,6 +38,11 @@ flags.DEFINE_integer( ...@@ -38,6 +38,11 @@ flags.DEFINE_integer(
default=None, default=None,
help='The number of total training steps for the pretraining job.') help='The number of total training steps for the pretraining job.')
flags.DEFINE_bool(
'enable_async_checkpointing',
default=True,
help='A boolean indicating whether to enable async checkpoint saving')
def _run_experiment_with_preemption_recovery(params, model_dir): def _run_experiment_with_preemption_recovery(params, model_dir):
"""Runs experiment and tries to reconnect when encounting a preemption.""" """Runs experiment and tries to reconnect when encounting a preemption."""
...@@ -53,14 +58,17 @@ def _run_experiment_with_preemption_recovery(params, model_dir): ...@@ -53,14 +58,17 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
**params.runtime.model_parallelism()) **params.runtime.model_parallelism())
with distribution_strategy.scope(): with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir) task = task_factory.get_task(params.task, logging_dir=model_dir)
preemption_watcher = tf.distribute.experimental.PreemptionWatcher() # pylint: disable=line-too-long
preemption_watcher = None # copybara-replace
# pylint: enable=line-too-long
train_lib.run_experiment( train_lib.run_experiment(
distribution_strategy=distribution_strategy, distribution_strategy=distribution_strategy,
task=task, task=task,
mode=FLAGS.mode, mode=FLAGS.mode,
params=params, params=params,
model_dir=model_dir) model_dir=model_dir,
enable_async_checkpointing=FLAGS.enable_async_checkpointing)
keep_training = False keep_training = False
except tf.errors.OpError as e: except tf.errors.OpError as e:
......
...@@ -19,4 +19,4 @@ from official.projects.maxvit.configs import backbones # pylint:disable=unused- ...@@ -19,4 +19,4 @@ from official.projects.maxvit.configs import backbones # pylint:disable=unused-
from official.projects.maxvit.configs import rcnn # pylint:disable=unused-import from official.projects.maxvit.configs import rcnn # pylint:disable=unused-import
from official.projects.maxvit.configs import retinanet # pylint:disable=unused-import from official.projects.maxvit.configs import retinanet # pylint:disable=unused-import
from official.projects.maxvit.configs import semantic_segmentation # pylint:disable=unused-import from official.projects.maxvit.configs import semantic_segmentation # pylint:disable=unused-import
from official.projects.maxvit.configs.google import image_classification # pylint:disable=unused-import from official.projects.maxvit.configs import image_classification # pylint:disable=unused-import
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
# pylint: disable=g-bad-import-order
from official.vision import registry_imports
from official.projects.maxvit import configs # pylint: disable=unused-import
from official.projects.maxvit.modeling import maxvit # pylint: disable=unused-import
...@@ -12,12 +12,11 @@ ...@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""TensorFlow Model Garden Vision training driver, including ViT configs..""" """TensorFlow Model Garden Vision training driver, including MaxViT configs.."""
from absl import app from absl import app
from official.common import flags as tfm_flags from official.common import flags as tfm_flags
from official.projects.maxvit import configs # pylint: disable=unused-import from official.projects.maxvit import registry_imports # pylint: disable=unused-import
from official.projects.maxvit.modeling import maxvit # pylint: disable=unused-import
from official.vision import train from official.vision import train
......
...@@ -114,7 +114,6 @@ def convert_predictions_to_coco_annotations(predictions): ...@@ -114,7 +114,6 @@ def convert_predictions_to_coco_annotations(predictions):
Required fields: Required fields:
- source_id: a list of numpy arrays of int or string of shape - source_id: a list of numpy arrays of int or string of shape
[batch_size]. [batch_size].
- num_detections: a list of numpy arrays of int of shape [batch_size].
- detection_boxes: a list of numpy arrays of float of shape - detection_boxes: a list of numpy arrays of float of shape
[batch_size, K, 4], where coordinates are in the original image [batch_size, K, 4], where coordinates are in the original image
space (not the scaled image space). space (not the scaled image space).
...@@ -125,6 +124,8 @@ def convert_predictions_to_coco_annotations(predictions): ...@@ -125,6 +124,8 @@ def convert_predictions_to_coco_annotations(predictions):
Optional fields: Optional fields:
- detection_masks: a list of numpy arrays of float of shape - detection_masks: a list of numpy arrays of float of shape
[batch_size, K, mask_height, mask_width]. [batch_size, K, mask_height, mask_width].
- detection_keypoints: a list of numpy arrays of float of shape
[batch_size, K, num_keypoints, 2]
Returns: Returns:
coco_predictions: prediction in COCO annotation format. coco_predictions: prediction in COCO annotation format.
...@@ -144,17 +145,32 @@ def convert_predictions_to_coco_annotations(predictions): ...@@ -144,17 +145,32 @@ def convert_predictions_to_coco_annotations(predictions):
mask_boxes = predictions['detection_boxes'] mask_boxes = predictions['detection_boxes']
batch_size = predictions['source_id'][i].shape[0] batch_size = predictions['source_id'][i].shape[0]
if 'detection_keypoints' in predictions:
# Adds extra ones to indicate the visibility for each keypoint as is
# recommended by MSCOCO. Also, convert keypoint from [y, x] to [x, y]
# as mandated by COCO.
num_keypoints = predictions['detection_keypoints'][i].shape[2]
coco_keypoints = np.concatenate(
[
predictions['detection_keypoints'][i][..., 1:],
predictions['detection_keypoints'][i][..., :1],
np.ones([batch_size, max_num_detections, num_keypoints, 1]),
],
axis=-1,
).astype(int)
for j in range(batch_size): for j in range(batch_size):
if 'detection_masks' in predictions: if 'detection_masks' in predictions:
image_masks = mask_ops.paste_instance_masks( image_masks = mask_ops.paste_instance_masks(
predictions['detection_masks'][i][j], predictions['detection_masks'][i][j],
mask_boxes[i][j], mask_boxes[i][j],
int(predictions['image_info'][i][j, 0, 0]), int(predictions['image_info'][i][j, 0, 0]),
int(predictions['image_info'][i][j, 0, 1])) int(predictions['image_info'][i][j, 0, 1]),
)
binary_masks = (image_masks > 0.0).astype(np.uint8) binary_masks = (image_masks > 0.0).astype(np.uint8)
encoded_masks = [ encoded_masks = [
mask_api.encode(np.asfortranarray(binary_mask)) mask_api.encode(np.asfortranarray(binary_mask))
for binary_mask in list(binary_masks)] for binary_mask in list(binary_masks)
]
for k in range(max_num_detections): for k in range(max_num_detections):
ann = {} ann = {}
ann['image_id'] = predictions['source_id'][i][j] ann['image_id'] = predictions['source_id'][i][j]
...@@ -164,21 +180,7 @@ def convert_predictions_to_coco_annotations(predictions): ...@@ -164,21 +180,7 @@ def convert_predictions_to_coco_annotations(predictions):
if 'detection_masks' in predictions: if 'detection_masks' in predictions:
ann['segmentation'] = encoded_masks[k] ann['segmentation'] = encoded_masks[k]
if 'detection_keypoints' in predictions: if 'detection_keypoints' in predictions:
# Adds extra ones to indicate the visibility for each keypoint as is ann['keypoints'] = coco_keypoints[j, k].flatten().tolist()
# recommended by MSCOCO. Also, convert keypoint from [y, x] to [x, y]
# as mandated by COCO.
instance_keypoints = predictions['detection_keypoints'][i][j, k]
num_keypoints = len(instance_keypoints)
instance_keypoints = np.concatenate(
[
np.expand_dims(instance_keypoints[:, 1], axis=-1),
np.expand_dims(instance_keypoints[:, 0], axis=-1),
np.expand_dims(np.ones(num_keypoints), axis=1),
],
axis=1,
).astype(int)
instance_keypoints = instance_keypoints.flatten().tolist()
ann['keypoints'] = instance_keypoints
coco_predictions.append(ann) coco_predictions.append(ann)
for i, ann in enumerate(coco_predictions): for i, ann in enumerate(coco_predictions):
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import os import os
import numpy as np
import tensorflow as tf import tensorflow as tf
from official.vision.dataloaders import tfexample_utils from official.vision.dataloaders import tfexample_utils
...@@ -27,11 +28,13 @@ class CocoUtilsTest(tf.test.TestCase): ...@@ -27,11 +28,13 @@ class CocoUtilsTest(tf.test.TestCase):
def test_scan_and_generator_annotation_file(self): def test_scan_and_generator_annotation_file(self):
num_samples = 10 num_samples = 10
example = tfexample_utils.create_detection_test_example( example = tfexample_utils.create_detection_test_example(
image_height=512, image_width=512, image_channel=3, num_instances=10) image_height=512, image_width=512, image_channel=3, num_instances=10
)
tf_examples = [example] * num_samples tf_examples = [example] * num_samples
data_file = os.path.join(self.create_tempdir(), 'test.tfrecord') data_file = os.path.join(self.create_tempdir(), 'test.tfrecord')
tfexample_utils.dump_to_tfrecord( tfexample_utils.dump_to_tfrecord(
record_file=data_file, tf_examples=tf_examples) record_file=data_file, tf_examples=tf_examples
)
annotation_file = os.path.join(self.create_tempdir(), 'annotation.json') annotation_file = os.path.join(self.create_tempdir(), 'annotation.json')
coco_utils.scan_and_generator_annotation_file( coco_utils.scan_and_generator_annotation_file(
...@@ -39,10 +42,53 @@ class CocoUtilsTest(tf.test.TestCase): ...@@ -39,10 +42,53 @@ class CocoUtilsTest(tf.test.TestCase):
file_type='tfrecord', file_type='tfrecord',
num_samples=num_samples, num_samples=num_samples,
include_mask=True, include_mask=True,
annotation_file=annotation_file) annotation_file=annotation_file,
)
self.assertTrue( self.assertTrue(
tf.io.gfile.exists(annotation_file), tf.io.gfile.exists(annotation_file),
msg='Annotation file {annotation_file} does not exists.') msg='Annotation file {annotation_file} does not exist.',
)
def test_convert_keypoint_predictions_to_coco_annotations(self):
batch_size = 1
max_num_detections = 3
num_keypoints = 3
image_size = 512
source_id = [np.array([[1]], dtype=int)]
detection_boxes = [
np.random.random([batch_size, max_num_detections, 4]) * image_size
]
detection_class = [
np.random.randint(1, 5, [batch_size, max_num_detections])
]
detection_scores = [np.random.random([batch_size, max_num_detections])]
detection_keypoints = [
np.random.random([batch_size, max_num_detections, num_keypoints, 2])
* image_size
]
predictions = {
'source_id': source_id,
'detection_boxes': detection_boxes,
'detection_classes': detection_class,
'detection_scores': detection_scores,
'detection_keypoints': detection_keypoints,
}
anns = coco_utils.convert_predictions_to_coco_annotations(predictions)
for i in range(max_num_detections):
expected_keypoint_ann = np.concatenate(
[
np.expand_dims(detection_keypoints[0][0, i, :, 1], axis=-1),
np.expand_dims(detection_keypoints[0][0, i, :, 0], axis=-1),
np.expand_dims(np.ones(num_keypoints), axis=1),
],
axis=1,
).astype(int)
expected_keypoint_ann = expected_keypoint_ann.flatten().tolist()
self.assertAllEqual(anns[i]['keypoints'], expected_keypoint_ann)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -32,6 +32,11 @@ from official.vision.utils import summary_manager ...@@ -32,6 +32,11 @@ from official.vision.utils import summary_manager
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_bool(
'enable_async_checkpointing',
default=True,
help='A boolean indicating whether to enable async checkpoint saving')
def _run_experiment_with_preemption_recovery(params, model_dir): def _run_experiment_with_preemption_recovery(params, model_dir):
"""Runs experiment and tries to reconnect when encounting a preemption.""" """Runs experiment and tries to reconnect when encounting a preemption."""
...@@ -46,7 +51,9 @@ def _run_experiment_with_preemption_recovery(params, model_dir): ...@@ -46,7 +51,9 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
tpu_address=params.runtime.tpu) tpu_address=params.runtime.tpu)
with distribution_strategy.scope(): with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir) task = task_factory.get_task(params.task, logging_dir=model_dir)
preemption_watcher = tf.distribute.experimental.PreemptionWatcher() # pylint: disable=line-too-long
preemption_watcher = None # copybara-replace
# pylint: enable=line-too-long
train_lib.run_experiment( train_lib.run_experiment(
distribution_strategy=distribution_strategy, distribution_strategy=distribution_strategy,
...@@ -58,6 +65,7 @@ def _run_experiment_with_preemption_recovery(params, model_dir): ...@@ -58,6 +65,7 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
eval_summary_manager=summary_manager.maybe_build_eval_summary_manager( eval_summary_manager=summary_manager.maybe_build_eval_summary_manager(
params=params, model_dir=model_dir params=params, model_dir=model_dir
), ),
enable_async_checkpointing=FLAGS.enable_async_checkpointing,
) )
keep_training = False keep_training = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册