提交 5d38628c 编写于 作者: Y Yeqing Li 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 457890217
上级 cfcbb6cb
......@@ -88,6 +88,7 @@ class DataConfig(cfg.DataConfig):
def yt8m(is_training):
"""YT8M dataset configs."""
# pylint: disable=unexpected-keyword-arg
return DataConfig(
num_frames=30,
temporal_stride=1,
......@@ -95,8 +96,10 @@ def yt8m(is_training):
segment_size=5,
is_training=is_training,
split='train' if is_training else 'valid',
drop_remainder=is_training, # pytype: disable=wrong-keyword-args
num_examples=YT8M_TRAIN_EXAMPLES if is_training else YT8M_VAL_EXAMPLES,
input_path=YT8M_TRAIN_PATH if is_training else YT8M_VAL_PATH)
# pylint: enable=unexpected-keyword-arg
@dataclasses.dataclass
......
......@@ -22,7 +22,6 @@
back into a range between min_quantized_value and max_quantized_value.
link for details: https://research.google.com/youtube8m/download.html
"""
from typing import Dict
import tensorflow as tf
......@@ -424,6 +423,7 @@ class PostBatchProcessor():
[-1, self.num_classes])
else:
# NOTE(b/237445211): Must provide axis argument to tf.squeeze.
video_matrix = tf.squeeze(video_matrix, axis=1)
labels = tf.squeeze(labels, axis=1)
......@@ -449,13 +449,15 @@ class TransformBatcher():
self._global_batch_size = input_params.global_batch_size
self._is_training = input_params.is_training
self._include_video_id = input_params.include_video_id
self._drop_remainder = input_params.drop_remainder
def batch_fn(self, dataset, input_context):
"""Add padding when segment_labels is true."""
per_replica_batch_size = input_context.get_per_replica_batch_size(
self._global_batch_size) if input_context else self._global_batch_size
if not self._segment_labels:
dataset = dataset.batch(per_replica_batch_size, drop_remainder=True)
dataset = dataset.batch(
per_replica_batch_size, drop_remainder=self._drop_remainder)
else:
# add padding
pad_shapes = {
......@@ -476,6 +478,6 @@ class TransformBatcher():
dataset = dataset.padded_batch(
per_replica_batch_size,
padded_shapes=pad_shapes,
drop_remainder=True,
drop_remainder=self._drop_remainder,
padding_values=pad_values)
return dataset
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册