Commit 2b360b0e authored by Dan Smilkov, committed by TensorFlower Gardener

Add support for user-generated tsv tensor files in the embedding projector.

Change: 137954868
Parent c66f3a4d
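
For context: instead of requiring a checkpoint, an embedding can now point at a raw TSV file via the new tensor_path field. A minimal sketch of a projector_config.pbtxt exercising this (file names are hypothetical), parsed the same way the plugin does it:

    from google.protobuf import text_format
    from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import ProjectorConfig

    config = ProjectorConfig()
    text_format.Merge('''
    embeddings {
      tensor_path: "embeddings.tsv"   # one point per line, dims separated by tabs
      metadata_path: "metadata.tsv"   # optional per-point labels
    }
    ''', config)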
@@ -28,11 +28,12 @@ from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
from tensorflow.python.training.saver import checkpoint_exists
from tensorflow.python.training.saver import latest_checkpoint
from tensorflow.tensorboard.plugins.base_plugin import TBPlugin
# HTTP routes.
INFO_ROUTE = '/info'
CONFIG_ROUTE = '/info'
TENSOR_ROUTE = '/tensor'
METADATA_ROUTE = '/metadata'
RUNS_ROUTE = '/runs'
@@ -51,6 +52,15 @@ _IMGHDR_TO_MIMETYPE = {
_DEFAULT_IMAGE_MIMETYPE = 'application/octet-stream'
def _read_tensor_file(fpath):
with file_io.FileIO(fpath, 'r') as f:
tensor = []
for line in f:
if line:
tensor.append(line.rstrip('\n').split('\t'))
return tensor
class ProjectorPlugin(TBPlugin):
"""Embedding projector."""
@@ -58,16 +68,47 @@ class ProjectorPlugin(TBPlugin):
self.configs, self.config_fpaths = self._read_config_files(run_paths,
logdir)
self.readers = {}
self._augment_configs_with_checkpoint_info()
return {
RUNS_ROUTE: self._serve_runs,
INFO_ROUTE: self._serve_info,
CONFIG_ROUTE: self._serve_config,
TENSOR_ROUTE: self._serve_tensor,
METADATA_ROUTE: self._serve_metadata,
BOOKMARKS_ROUTE: self._serve_bookmarks,
SPRITE_IMAGE_ROUTE: self._serve_sprite_image
}
def _augment_configs_with_checkpoint_info(self):
for run, config in self.configs.items():
# Find the size of the embeddings that are associated with a tensor file.
for embedding in config.embeddings:
if embedding.tensor_path and not embedding.tensor_shape:
tensor = _read_tensor_file(embedding.tensor_path)
embedding.tensor_shape.extend([len(tensor), len(tensor[0])])
reader = self._get_reader_for_run(run)
if not reader:
continue
# Augment the configuration with the tensors in the checkpoint file.
special_embedding = None
if config.embeddings and not config.embeddings[0].tensor_name:
special_embedding = config.embeddings[0]
config.embeddings.remove(special_embedding)
var_map = reader.get_variable_to_shape_map()
for tensor_name, tensor_shape in var_map.items():
if len(tensor_shape) != 2:
continue
embedding = self._get_embedding(tensor_name, config)
if not embedding:
embedding = config.embeddings.add()
embedding.tensor_name = tensor_name
if special_embedding:
embedding.metadata_path = special_embedding.metadata_path
embedding.bookmarks_path = special_embedding.bookmarks_path
if not embedding.tensor_shape:
embedding.tensor_shape.extend(tensor_shape)
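
Note the special_embedding branch above: a leading embedding with no tensor_name acts as a template whose metadata_path and bookmarks_path are copied onto every 2-D variable discovered in the checkpoint. A sketch of such a config (paths hypothetical):

    config = ProjectorConfig()
    text_format.Merge('''
    model_checkpoint_path: "model.ckpt-5000"
    embeddings {
      # No tensor_name: applied as a template to all 2-D checkpoint tensors.
      metadata_path: "metadata.tsv"
      bookmarks_path: "bookmarks.txt"
    }
    ''', config)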
def _read_config_files(self, run_paths, logdir):
# If there are no summary event files, the projector can still work,
# thus treating the `logdir` as the model checkpoint directory.
@@ -77,28 +118,33 @@ class ProjectorPlugin(TBPlugin):
configs = {}
config_fpaths = {}
for run_name, logdir in run_paths.items():
config = ProjectorConfig()
config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
if not file_io.file_exists(config_fpath):
# Skip runs that have no config file.
continue
# Read the config file.
if file_io.file_exists(config_fpath):
file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
config = ProjectorConfig()
text_format.Merge(file_content, config)
has_tensor_files = False
for embedding in config.embeddings:
if embedding.tensor_path:
has_tensor_files = True
break
if not config.model_checkpoint_path:
# See if you can find a checkpoint file in the logdir.
ckpt_path = latest_checkpoint(logdir)
if not ckpt_path:
# Or in the parent of logdir.
ckpt_path = latest_checkpoint(os.path.join('../', logdir))
if not ckpt_path:
if not ckpt_path and not has_tensor_files:
logging.warning('Cannot find model checkpoint in %s', logdir)
continue
if ckpt_path:
config.model_checkpoint_path = ckpt_path
# Sanity check for the checkpoint file.
if not file_io.file_exists(config.model_checkpoint_path):
if (config.model_checkpoint_path and
not checkpoint_exists(config.model_checkpoint_path)):
logging.warning('Checkpoint file %s not found',
config.model_checkpoint_path)
continue
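
The switch from file_io.file_exists to checkpoint_exists matters for V2 checkpoints, where model_checkpoint_path is a prefix rather than a literal file on disk. A sketch:

    # A V2 checkpoint saved under the prefix 'model.ckpt-5000' exists on disk as
    # 'model.ckpt-5000.index' plus 'model.ckpt-5000.data-00000-of-00001', so:
    file_io.file_exists('model.ckpt-5000')   # False: the prefix is not a file
    checkpoint_exists('model.ckpt-5000')     # True: also matches the sharded files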
@@ -111,18 +157,20 @@ class ProjectorPlugin(TBPlugin):
return self.readers[run]
config = self.configs[run]
reader = None
if config.model_checkpoint_path:
reader = NewCheckpointReader(config.model_checkpoint_path)
self.readers[run] = reader
return reader
def _get_metadata_file_for_tensor(self, tensor_name, config):
embedding_info = self._get_embedding_info_for_tensor(tensor_name, config)
embedding_info = self._get_embedding(tensor_name, config)
if embedding_info:
return embedding_info.metadata_path
return None
def _get_bookmarks_file_for_tensor(self, tensor_name, config):
embedding_info = self._get_embedding_info_for_tensor(tensor_name, config)
embedding_info = self._get_embedding(tensor_name, config)
if embedding_info:
return embedding_info.bookmarks_path
return None
@@ -133,7 +181,7 @@ class ProjectorPlugin(TBPlugin):
else:
return tensor_name
def _get_embedding_info_for_tensor(self, tensor_name, config):
def _get_embedding(self, tensor_name, config):
if not config.embeddings:
return None
for info in config.embeddings:
@@ -146,7 +194,7 @@ class ProjectorPlugin(TBPlugin):
"""Returns a list of runs that have embeddings."""
self.handler.respond(list(self.configs.keys()), 'application/json')
def _serve_info(self, query_params):
def _serve_config(self, query_params):
run = query_params.get('run')
if run is None:
self.handler.respond('query parameter "run" is required',
@@ -157,19 +205,6 @@ class ProjectorPlugin(TBPlugin):
return
config = self.configs[run]
reader = self._get_reader_for_run(run)
var_map = reader.get_variable_to_shape_map()
for tensor_name, tensor_shape in var_map.items():
if len(tensor_shape) != 2:
continue
info = self._get_embedding_info_for_tensor(tensor_name, config)
if not info:
info = config.embeddings.add()
info.tensor_name = tensor_name
if not info.tensor_shape:
info.tensor_shape.extend(tensor_shape)
self.handler.respond(json_format.MessageToJson(config), 'application/json')
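
With the augmentation moved into _augment_configs_with_checkpoint_info at setup time, _serve_config reduces to serializing the stored proto. Roughly the JSON a client would see (values hypothetical; MessageToJson camelCases the proto field names):

    # self.handler.respond(json_format.MessageToJson(config), 'application/json')
    # might yield something like:
    # {
    #   "modelCheckpointPath": "model.ckpt-5000",
    #   "embeddings": [
    #     {"tensorName": "word_embedding", "tensorShape": [10000, 128]}
    #   ]
    # }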
def _serve_metadata(self, query_params):
@@ -192,7 +227,7 @@ class ProjectorPlugin(TBPlugin):
fpath = self._get_metadata_file_for_tensor(name, config)
if not fpath:
self.handler.respond(
'Not metadata file found for tensor %s in the config file %s' %
'No metadata file found for tensor %s in the config file %s' %
(name, self.config_fpaths[run]), 'text/plain', 400)
return
if not file_io.file_exists(fpath) or file_io.is_directory(fpath):
@@ -231,12 +266,27 @@ class ProjectorPlugin(TBPlugin):
reader = self._get_reader_for_run(run)
config = self.configs[run]
if reader is None:
# See if there is a tensor file in the config.
embedding = self._get_embedding(name, config)
if not embedding or not embedding.tensor_path:
self.handler.respond('Tensor %s has no tensor_path in the config' %
name, 'text/plain', 400)
return
if not file_io.file_exists(embedding.tensor_path):
self.handler.respond('Tensor file %s does not exist' %
embedding.tensor_path, 'text/plain', 400)
return
tensor = _read_tensor_file(embedding.tensor_path)
else:
if not reader.has_tensor(name):
self.handler.respond('Tensor %s not found in checkpoint dir %s' %
(name, config.model_checkpoint_path),
'text/plain', 400)
return
tensor = reader.get_tensor(name)
# Sample the tensor
tensor = tensor[:LIMIT_NUM_POINTS]
# Stream it as TSV.
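
The hunk cuts off at the "Stream it as TSV" comment. Conceptually, each sampled row is joined with tabs and rows with newlines; a minimal sketch of that serialization (a hypothetical helper, not the plugin's actual code):

    def _rows_to_tsv(rows):
      # Each row becomes one tab-separated line; rows are newline-separated.
      return '\n'.join('\t'.join(str(v) for v in row) for row in rows)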
@@ -294,7 +344,7 @@ class ProjectorPlugin(TBPlugin):
return
config = self.configs[run]
embedding_info = self._get_embedding_info_for_tensor(name, config)
embedding_info = self._get_embedding(name, config)
if not embedding_info or not embedding_info.sprite.image_path:
self.handler.respond(
......