Unverified commit 7313d7a9, authored by Zihan Wang, committed by GitHub

Merge branch 'tensorflow:master' into master

......@@ -186,7 +186,7 @@
"exp_config = tfm.core.exp_factory.get_exp_config('resnet_imagenet')\n",
"tfds_name = 'cifar10'\n",
"ds,ds_info = tfds.load(\n",
"tfds_name\n",
"tfds_name,\n",
"with_info=True)\n",
"ds_info"
]
......
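The notebook hunk above only restores a missing comma after `tfds_name`. For reference, a minimal standalone version of the corrected call (assuming `tensorflow_datasets` is installed; the dataset name comes from the notebook itself):

```python
import tensorflow_datasets as tfds

# With the comma restored, 'cifar10' is the positional dataset name and
# with_info=True additionally returns the DatasetInfo object.
ds, ds_info = tfds.load('cifar10', with_info=True)
print(ds_info.features)
```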
......@@ -71,6 +71,7 @@ class OrbitExperimentRunner:
controller_cls=orbit.Controller,
summary_manager: Optional[orbit.utils.SummaryManager] = None,
eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
enable_async_checkpointing: bool = False,
):
"""Constructor.
......@@ -94,6 +95,8 @@ class OrbitExperimentRunner:
summary manager.
eval_summary_manager: Instance of the eval summary manager to override
default eval summary manager.
enable_async_checkpointing: Optional boolean indicating whether to enable
async checkpoint saving.
"""
self.strategy = distribution_strategy or tf.distribute.get_strategy()
self._params = params
......@@ -115,7 +118,8 @@ class OrbitExperimentRunner:
save_summary=save_summary,
train_actions=train_actions,
eval_actions=eval_actions,
controller_cls=controller_cls)
controller_cls=controller_cls,
enable_async_checkpointing=enable_async_checkpointing)
@property
def params(self) -> config_definitions.ExperimentConfig:
......@@ -188,13 +192,16 @@ class OrbitExperimentRunner:
checkpoint_manager = None
return checkpoint_manager
def _build_controller(self,
def _build_controller(
self,
trainer,
evaluator,
save_summary: bool = True,
train_actions: Optional[List[orbit.Action]] = None,
eval_actions: Optional[List[orbit.Action]] = None,
controller_cls=orbit.Controller) -> orbit.Controller:
controller_cls=orbit.Controller,
enable_async_checkpointing: bool = False,
) -> orbit.Controller:
"""Builds a Orbit controler."""
train_actions = [] if not train_actions else train_actions
if trainer:
......@@ -223,6 +230,7 @@ class OrbitExperimentRunner:
global_step=self.trainer.global_step,
steps_per_loop=self.params.trainer.steps_per_loop,
checkpoint_manager=self.checkpoint_manager,
enable_async_checkpointing=enable_async_checkpointing,
summary_dir=os.path.join(self.model_dir, 'train')
if (save_summary)
else None,
......@@ -309,6 +317,7 @@ def run_experiment(
controller_cls=orbit.Controller,
summary_manager: Optional[orbit.utils.SummaryManager] = None,
eval_summary_manager: Optional[orbit.utils.SummaryManager] = None,
enable_async_checkpointing: bool = False,
) -> Tuple[tf.keras.Model, Mapping[str, Any]]:
"""Runs train/eval configured by the experiment params.
......@@ -332,6 +341,8 @@ def run_experiment(
manager.
eval_summary_manager: Instance of the eval summary manager to override
default eval summary manager.
enable_async_checkpointing: Optional boolean indicating whether to enable
async checkpoint saving.
Returns:
A 2-tuple of (model, eval_logs).
......@@ -353,5 +364,6 @@ def run_experiment(
controller_cls=controller_cls,
summary_manager=summary_manager,
eval_summary_manager=eval_summary_manager,
enable_async_checkpointing=enable_async_checkpointing,
)
return runner.run()
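For context, a minimal sketch of how a caller might opt into the new async checkpointing path through `run_experiment`. The experiment name, model directory, and mode here are illustrative placeholders, not part of this change; only the `enable_async_checkpointing` parameter is introduced by this commit:

```python
import tensorflow as tf
from official.core import exp_factory, task_factory, train_lib

# Illustrative setup; any registered experiment config would work here.
params = exp_factory.get_exp_config('resnet_imagenet')
strategy = tf.distribute.get_strategy()
with strategy.scope():
  task = task_factory.get_task(params.task, logging_dir='/tmp/model_dir')

model, eval_logs = train_lib.run_experiment(
    distribution_strategy=strategy,
    task=task,
    mode='train',
    params=params,
    model_dir='/tmp/model_dir',
    enable_async_checkpointing=True,  # new parameter; defaults to False
)
```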
......@@ -13,6 +13,7 @@
# limitations under the License.
"""Training utils."""
import dataclasses
import inspect
import json
......@@ -22,10 +23,12 @@ from typing import Any, Callable, Dict, List, Optional, Union
from absl import logging
import gin
import numpy as np
import orbit
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import ops
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
# pylint: enable=g-direct-tensorflow-import
from official.core import base_task
......@@ -564,3 +567,44 @@ def try_count_flops(model: Union[tf.Module, tf.keras.Model],
'reached before this run.', e)
return None
return None
@ops.RegisterStatistics('Einsum', 'flops')
def _einsum_flops(graph, node):
"""Calculates the compute resources needed for Einsum."""
assert len(node.input) == 2
x_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
graph, node.input[0])
y_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
graph, node.input[1])
x_shape.assert_is_fully_defined()
y_shape.assert_is_fully_defined()
x_shape = x_shape.as_list()
y_shape = y_shape.as_list()
equation = str(node.attr['equation'])
equation = (
equation.replace('s:', '')
.replace('"', '')
.replace(' ', '')
.replace('\n', '')
)
x_str = equation.split(',')[0]
y_r_str = equation.split(',')[1]
y_str = y_r_str.split('->')[0]
r_str = y_r_str.split('->')[1]
shape_dic = {}
contracted = set()
for indice in x_str + y_str:
if indice in x_str:
indice_dim = x_shape[x_str.find(indice)]
elif indice in y_str:
indice_dim = y_shape[y_str.find(indice)]
else:
raise ValueError('indice {} not found in inputs'.format(indice))
shape_dic[indice] = indice_dim
if indice not in r_str:
contracted.add(indice)
madds = np.prod([shape_dic[indice] for indice in r_str]) * (
np.prod([shape_dic[indice] for indice in contracted]))
flops = 2 * madds
return ops.OpStats('flops', flops)
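The registered statistics function above counts one multiply-add per contracted element per output element, then doubles it to report FLOPs. A self-contained sketch of the same counting rule on plain Python shapes (illustrative only, not the TF op-stats API); for a standard matrix multiply it reduces to the familiar 2*M*N*K count:

```python
import numpy as np

def einsum_flops_estimate(equation, x_shape, y_shape):
  """Mirrors the counting rule above, for illustration only."""
  x_str, rest = equation.split(',')
  y_str, r_str = rest.split('->')
  shape_dic = {}
  contracted = set()
  for idx in x_str + y_str:
    shape_dic[idx] = (
        x_shape[x_str.find(idx)] if idx in x_str else y_shape[y_str.find(idx)])
    if idx not in r_str:  # index absent from the output is contracted
      contracted.add(idx)
  madds = np.prod([shape_dic[i] for i in r_str]) * np.prod(
      [shape_dic[i] for i in contracted])
  return 2 * madds

# A (4x8) @ (8x16) matmul: 4*16 outputs, each a dot product of length 8.
assert einsum_flops_estimate('ab,bc->ac', [4, 8], [8, 16]) == 2 * 4 * 16 * 8
```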
......@@ -38,6 +38,11 @@ flags.DEFINE_integer(
default=None,
help='The number of total training steps for the pretraining job.')
flags.DEFINE_bool(
'enable_async_checkpointing',
default=True,
help='A boolean indicating whether to enable async checkpoint saving')
def _run_experiment_with_preemption_recovery(params, model_dir):
"""Runs experiment and tries to reconnect when encounting a preemption."""
......@@ -53,14 +58,17 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
**params.runtime.model_parallelism())
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
preemption_watcher = tf.distribute.experimental.PreemptionWatcher()
# pylint: disable=line-too-long
preemption_watcher = None # copybara-replace
# pylint: enable=line-too-long
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
model_dir=model_dir,
enable_async_checkpointing=FLAGS.enable_async_checkpointing)
keep_training = False
except tf.errors.OpError as e:
......
......@@ -19,4 +19,4 @@ from official.projects.maxvit.configs import backbones # pylint:disable=unused-
from official.projects.maxvit.configs import rcnn # pylint:disable=unused-import
from official.projects.maxvit.configs import retinanet # pylint:disable=unused-import
from official.projects.maxvit.configs import semantic_segmentation # pylint:disable=unused-import
from official.projects.maxvit.configs.google import image_classification # pylint:disable=unused-import
from official.projects.maxvit.configs import image_classification # pylint:disable=unused-import
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
# pylint: disable=g-bad-import-order
from official.vision import registry_imports
from official.projects.maxvit import configs # pylint: disable=unused-import
from official.projects.maxvit.modeling import maxvit # pylint: disable=unused-import
......@@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorFlow Model Garden Vision training driver, including ViT configs.."""
"""TensorFlow Model Garden Vision training driver, including MaxViT configs.."""
from absl import app
from official.common import flags as tfm_flags
from official.projects.maxvit import configs # pylint: disable=unused-import
from official.projects.maxvit.modeling import maxvit # pylint: disable=unused-import
from official.projects.maxvit import registry_imports # pylint: disable=unused-import
from official.vision import train
......
......@@ -114,7 +114,6 @@ def convert_predictions_to_coco_annotations(predictions):
Required fields:
- source_id: a list of numpy arrays of int or string of shape
[batch_size].
- num_detections: a list of numpy arrays of int of shape [batch_size].
- detection_boxes: a list of numpy arrays of float of shape
[batch_size, K, 4], where coordinates are in the original image
space (not the scaled image space).
......@@ -125,6 +124,8 @@ def convert_predictions_to_coco_annotations(predictions):
Optional fields:
- detection_masks: a list of numpy arrays of float of shape
[batch_size, K, mask_height, mask_width].
- detection_keypoints: a list of numpy arrays of float of shape
[batch_size, K, num_keypoints, 2]
Returns:
coco_predictions: prediction in COCO annotation format.
......@@ -144,17 +145,32 @@ def convert_predictions_to_coco_annotations(predictions):
mask_boxes = predictions['detection_boxes']
batch_size = predictions['source_id'][i].shape[0]
if 'detection_keypoints' in predictions:
# Adds extra ones to indicate the visibility for each keypoint as is
# recommended by MSCOCO. Also, converts keypoints from [y, x] to [x, y]
# as mandated by COCO.
num_keypoints = predictions['detection_keypoints'][i].shape[2]
coco_keypoints = np.concatenate(
[
predictions['detection_keypoints'][i][..., 1:],
predictions['detection_keypoints'][i][..., :1],
np.ones([batch_size, max_num_detections, num_keypoints, 1]),
],
axis=-1,
).astype(int)
for j in range(batch_size):
if 'detection_masks' in predictions:
image_masks = mask_ops.paste_instance_masks(
predictions['detection_masks'][i][j],
mask_boxes[i][j],
int(predictions['image_info'][i][j, 0, 0]),
int(predictions['image_info'][i][j, 0, 1]))
int(predictions['image_info'][i][j, 0, 1]),
)
binary_masks = (image_masks > 0.0).astype(np.uint8)
encoded_masks = [
mask_api.encode(np.asfortranarray(binary_mask))
for binary_mask in list(binary_masks)]
for binary_mask in list(binary_masks)
]
for k in range(max_num_detections):
ann = {}
ann['image_id'] = predictions['source_id'][i][j]
......@@ -164,21 +180,7 @@ def convert_predictions_to_coco_annotations(predictions):
if 'detection_masks' in predictions:
ann['segmentation'] = encoded_masks[k]
if 'detection_keypoints' in predictions:
# Adds extra ones to indicate the visibility for each keypoint as is
# recommended by MSCOCO. Also, convert keypoint from [y, x] to [x, y]
# as mandated by COCO.
instance_keypoints = predictions['detection_keypoints'][i][j, k]
num_keypoints = len(instance_keypoints)
instance_keypoints = np.concatenate(
[
np.expand_dims(instance_keypoints[:, 1], axis=-1),
np.expand_dims(instance_keypoints[:, 0], axis=-1),
np.expand_dims(np.ones(num_keypoints), axis=1),
],
axis=1,
).astype(int)
instance_keypoints = instance_keypoints.flatten().tolist()
ann['keypoints'] = instance_keypoints
ann['keypoints'] = coco_keypoints[j, k].flatten().tolist()
coco_predictions.append(ann)
for i, ann in enumerate(coco_predictions):
......
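The refactor above hoists the per-instance keypoint conversion out of the inner detection loop into one vectorized `np.concatenate` over the whole batch. A small standalone numpy sketch of what that conversion produces (shapes chosen arbitrarily):

```python
import numpy as np

# One image, two detections, three keypoints each, stored as [y, x].
kpts = np.arange(12, dtype=float).reshape([1, 2, 3, 2])
coco_kpts = np.concatenate(
    [kpts[..., 1:], kpts[..., :1], np.ones([1, 2, 3, 1])], axis=-1
).astype(int)
# Every keypoint becomes [x, y, visibility=1], the layout COCO expects.
assert coco_kpts[0, 0, 0].tolist() == [1, 0, 1]
```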
......@@ -16,6 +16,7 @@
import os
import numpy as np
import tensorflow as tf
from official.vision.dataloaders import tfexample_utils
......@@ -27,11 +28,13 @@ class CocoUtilsTest(tf.test.TestCase):
def test_scan_and_generator_annotation_file(self):
num_samples = 10
example = tfexample_utils.create_detection_test_example(
image_height=512, image_width=512, image_channel=3, num_instances=10)
image_height=512, image_width=512, image_channel=3, num_instances=10
)
tf_examples = [example] * num_samples
data_file = os.path.join(self.create_tempdir(), 'test.tfrecord')
tfexample_utils.dump_to_tfrecord(
record_file=data_file, tf_examples=tf_examples)
record_file=data_file, tf_examples=tf_examples
)
annotation_file = os.path.join(self.create_tempdir(), 'annotation.json')
coco_utils.scan_and_generator_annotation_file(
......@@ -39,10 +42,53 @@ class CocoUtilsTest(tf.test.TestCase):
file_type='tfrecord',
num_samples=num_samples,
include_mask=True,
annotation_file=annotation_file)
annotation_file=annotation_file,
)
self.assertTrue(
tf.io.gfile.exists(annotation_file),
msg='Annotation file {annotation_file} does not exists.')
msg=f'Annotation file {annotation_file} does not exist.',
)
def test_convert_keypoint_predictions_to_coco_annotations(self):
batch_size = 1
max_num_detections = 3
num_keypoints = 3
image_size = 512
source_id = [np.array([[1]], dtype=int)]
detection_boxes = [
np.random.random([batch_size, max_num_detections, 4]) * image_size
]
detection_class = [
np.random.randint(1, 5, [batch_size, max_num_detections])
]
detection_scores = [np.random.random([batch_size, max_num_detections])]
detection_keypoints = [
np.random.random([batch_size, max_num_detections, num_keypoints, 2])
* image_size
]
predictions = {
'source_id': source_id,
'detection_boxes': detection_boxes,
'detection_classes': detection_class,
'detection_scores': detection_scores,
'detection_keypoints': detection_keypoints,
}
anns = coco_utils.convert_predictions_to_coco_annotations(predictions)
for i in range(max_num_detections):
expected_keypoint_ann = np.concatenate(
[
np.expand_dims(detection_keypoints[0][0, i, :, 1], axis=-1),
np.expand_dims(detection_keypoints[0][0, i, :, 0], axis=-1),
np.expand_dims(np.ones(num_keypoints), axis=1),
],
axis=1,
).astype(int)
expected_keypoint_ann = expected_keypoint_ann.flatten().tolist()
self.assertAllEqual(anns[i]['keypoints'], expected_keypoint_ann)
if __name__ == '__main__':
......
......@@ -32,6 +32,11 @@ from official.vision.utils import summary_manager
FLAGS = flags.FLAGS
flags.DEFINE_bool(
'enable_async_checkpointing',
default=True,
help='A boolean indicating whether to enable async checkpoint saving')
def _run_experiment_with_preemption_recovery(params, model_dir):
"""Runs experiment and tries to reconnect when encounting a preemption."""
......@@ -46,7 +51,9 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
task = task_factory.get_task(params.task, logging_dir=model_dir)
preemption_watcher = tf.distribute.experimental.PreemptionWatcher()
# pylint: disable=line-too-long
preemption_watcher = None # copybara-replace
# pylint: enable=line-too-long
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
......@@ -58,6 +65,7 @@ def _run_experiment_with_preemption_recovery(params, model_dir):
eval_summary_manager=summary_manager.maybe_build_eval_summary_manager(
params=params, model_dir=model_dir
),
enable_async_checkpointing=FLAGS.enable_async_checkpointing,
)
keep_training = False
......