From 8f07b9cef2d2c308d651663995d42ea0dd23c44a Mon Sep 17 00:00:00 2001 From: Chaochao Yan Date: Mon, 17 Apr 2023 14:53:54 -0700 Subject: [PATCH] Internal change PiperOrigin-RevId: 524957164 --- .../projects/yt8m/dataloaders/yt8m_input.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/official/projects/yt8m/dataloaders/yt8m_input.py b/official/projects/yt8m/dataloaders/yt8m_input.py index c4a789d08..9e6d34ac3 100644 --- a/official/projects/yt8m/dataloaders/yt8m_input.py +++ b/official/projects/yt8m/dataloaders/yt8m_input.py @@ -284,7 +284,7 @@ class Decoder(decoder.Decoder): self._context_features[name] = feature_type else: raise ValueError( - f"Unknow feature source {self._feature_sources[i]} for {name}") + f"Unknown feature source {self._feature_sources[i]} for {name}") def _add_labels_specification(self): if not self._label_field: @@ -333,6 +333,7 @@ class Parser(parser.Parser): self._segment_labels = input_params.segment_labels self._include_video_id = input_params.include_video_id self._feature_names = input_params.feature_names + self._feature_sources = input_params.feature_sources self._feature_sizes = input_params.feature_sizes self._feature_dtypes = input_params.feature_dtypes self._max_frames = input_params.max_frames @@ -398,6 +399,27 @@ class Parser(parser.Parser): def parse(decoded_tensors): """Parses the serialized example data.""" + + # Concatenate video features to all frames if there are both video-level + # (context) and frame-level (feature) features. + if "feature" in self._feature_sources: + # Take the first frame-level feature matrix; any of them would do, + # since all frame features are assumed to have the same number of frames. 
+ feature_idx = self._feature_sources.index("feature") + num_frames = tf.shape( + decoded_tensors[self._feature_names[feature_idx]] + )[0] + for feature_idx, feature_source in enumerate(self._feature_sources): + if feature_source == "context": + feature_name = self._feature_names[feature_idx] + context_tensor = tf.reshape( + decoded_tensors[feature_name], + shape=(1, self._feature_sizes[feature_idx]), + ) + decoded_tensors[feature_name] = tf.tile( + context_tensor, [num_frames, 1] + ) + if is_training: return self._parse_train_data(decoded_tensors) else: -- GitLab