提交 8f07b9ce 编写于 作者: C Chaochao Yan 提交者: A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 524957164
上级 849f5268
......@@ -284,7 +284,7 @@ class Decoder(decoder.Decoder):
self._context_features[name] = feature_type
else:
raise ValueError(
f"Unknow feature source {self._feature_sources[i]} for {name}")
f"Unknown feature source {self._feature_sources[i]} for {name}")
def _add_labels_specification(self):
if not self._label_field:
......@@ -333,6 +333,7 @@ class Parser(parser.Parser):
self._segment_labels = input_params.segment_labels
self._include_video_id = input_params.include_video_id
self._feature_names = input_params.feature_names
self._feature_sources = input_params.feature_sources
self._feature_sizes = input_params.feature_sizes
self._feature_dtypes = input_params.feature_dtypes
self._max_frames = input_params.max_frames
......@@ -398,6 +399,27 @@ class Parser(parser.Parser):
def parse(decoded_tensors):
"""Parses the serialized example data."""
# Concatenate video features to all frames if there are both video-level
# (context) and frame-level (feature) features.
if "feature" in self._feature_sources:
# Take first frame feature matrix, any feature matrix should be fine
# since assume all frame features have same number of frames.
feature_idx = self._feature_sources.index("feature")
num_frames = tf.shape(
decoded_tensors[self._feature_names[feature_idx]]
)[0]
for feature_idx, feature_source in enumerate(self._feature_sources):
if feature_source == "context":
feature_name = self._feature_names[feature_idx]
context_tensor = tf.reshape(
decoded_tensors[feature_name],
shape=(1, self._feature_sizes[feature_idx]),
)
decoded_tensors[feature_name] = tf.tile(
context_tensor, [num_frames, 1]
)
if is_training:
return self._parse_train_data(decoded_tensors)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册