From 2b360b0e6d1ba56cc94c08f3706d80957089ebda Mon Sep 17 00:00:00 2001
From: Dan Smilkov
Date: Wed, 2 Nov 2016 08:32:11 -0800
Subject: [PATCH] Add support for user-generated tsv tensor files in the
 embedding projector.

Change: 137954868
---
 .../tensorboard/plugins/projector/plugin.py | 126 ++++++++++++------
 1 file changed, 88 insertions(+), 38 deletions(-)

diff --git a/tensorflow/tensorboard/plugins/projector/plugin.py b/tensorflow/tensorboard/plugins/projector/plugin.py
index b6d19650943..4b0c36ce5eb 100644
--- a/tensorflow/tensorboard/plugins/projector/plugin.py
+++ b/tensorflow/tensorboard/plugins/projector/plugin.py
@@ -28,11 +28,12 @@ from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 impor
 from tensorflow.python.lib.io import file_io
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.pywrap_tensorflow import NewCheckpointReader
+from tensorflow.python.training.saver import checkpoint_exists
 from tensorflow.python.training.saver import latest_checkpoint
 from tensorflow.tensorboard.plugins.base_plugin import TBPlugin
 
 # HTTP routes.
-INFO_ROUTE = '/info'
+CONFIG_ROUTE = '/info'
 TENSOR_ROUTE = '/tensor'
 METADATA_ROUTE = '/metadata'
 RUNS_ROUTE = '/runs'
@@ -51,6 +52,15 @@ _IMGHDR_TO_MIMETYPE = {
 _DEFAULT_IMAGE_MIMETYPE = 'application/octet-stream'
 
 
+def _read_tensor_file(fpath):
+  with file_io.FileIO(fpath, 'r') as f:
+    tensor = []
+    for line in f:
+      if line:
+        tensor.append(line.rstrip('\n').split('\t'))
+  return tensor
+
+
 class ProjectorPlugin(TBPlugin):
   """Embedding projector."""
 
@@ -58,16 +68,47 @@ class ProjectorPlugin(TBPlugin):
     self.configs, self.config_fpaths = self._read_config_files(run_paths,
                                                                logdir)
     self.readers = {}
+    self._augment_configs_with_checkpoint_info()
 
     return {
         RUNS_ROUTE: self._serve_runs,
-        INFO_ROUTE: self._serve_info,
+        CONFIG_ROUTE: self._serve_config,
         TENSOR_ROUTE: self._serve_tensor,
         METADATA_ROUTE: self._serve_metadata,
         BOOKMARKS_ROUTE: self._serve_bookmarks,
         SPRITE_IMAGE_ROUTE: self._serve_sprite_image
     }
 
+  def _augment_configs_with_checkpoint_info(self):
+    for run, config in self.configs.items():
+      # Find the size of the embeddings that are associated with a tensor file.
+      for embedding in config.embeddings:
+        if embedding.tensor_path and not embedding.tensor_shape:
+          tensor = _read_tensor_file(embedding.tensor_path)
+          embedding.tensor_shape.extend([len(tensor), len(tensor[0])])
+
+      reader = self._get_reader_for_run(run)
+      if not reader:
+        continue
+      # Augment the configuration with the tensors in the checkpoint file.
+      special_embedding = None
+      if config.embeddings and not config.embeddings[0].tensor_name:
+        special_embedding = config.embeddings[0]
+        config.embeddings.remove(special_embedding)
+      var_map = reader.get_variable_to_shape_map()
+      for tensor_name, tensor_shape in var_map.items():
+        if len(tensor_shape) != 2:
+          continue
+        embedding = self._get_embedding(tensor_name, config)
+        if not embedding:
+          embedding = config.embeddings.add()
+          embedding.tensor_name = tensor_name
+          if special_embedding:
+            embedding.metadata_path = special_embedding.metadata_path
+            embedding.bookmarks_path = special_embedding.bookmarks_path
+        if not embedding.tensor_shape:
+          embedding.tensor_shape.extend(tensor_shape)
+
   def _read_config_files(self, run_paths, logdir):
     # If there are no summary event files, the projector can still work,
     # thus treating the `logdir` as the model checkpoint directory.
@@ -77,14 +118,17 @@ class ProjectorPlugin(TBPlugin):
     configs = {}
     config_fpaths = {}
     for run_name, logdir in run_paths.items():
-      config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
-      if not file_io.file_exists(config_fpath):
-        # Skip runs that have no config file.
-        continue
-      # Read the config file.
-      file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
       config = ProjectorConfig()
-      text_format.Merge(file_content, config)
+      config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
+      if file_io.file_exists(config_fpath):
+        file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
+        text_format.Merge(file_content, config)
+
+      has_tensor_files = False
+      for embedding in config.embeddings:
+        if embedding.tensor_path:
+          has_tensor_files = True
+          break
 
       if not config.model_checkpoint_path:
         # See if you can find a checkpoint file in the logdir.
@@ -92,13 +136,15 @@ class ProjectorPlugin(TBPlugin):
         if not ckpt_path:
           # Or in the parent of logdir.
           ckpt_path = latest_checkpoint(os.path.join('../', logdir))
-        if not ckpt_path:
+        if not ckpt_path and not has_tensor_files:
           logging.warning('Cannot find model checkpoint in %s', logdir)
           continue
-      config.model_checkpoint_path = ckpt_path
+        if ckpt_path:
+          config.model_checkpoint_path = ckpt_path
 
       # Sanity check for the checkpoint file.
-      if not file_io.file_exists(config.model_checkpoint_path):
+      if (config.model_checkpoint_path and
+          not checkpoint_exists(config.model_checkpoint_path)):
         logging.warning('Checkpoint file %s not found',
                         config.model_checkpoint_path)
         continue
@@ -111,18 +157,20 @@ class ProjectorPlugin(TBPlugin):
       return self.readers[run]
 
     config = self.configs[run]
-    reader = NewCheckpointReader(config.model_checkpoint_path)
+    reader = None
+    if config.model_checkpoint_path:
+      reader = NewCheckpointReader(config.model_checkpoint_path)
     self.readers[run] = reader
     return reader
 
   def _get_metadata_file_for_tensor(self, tensor_name, config):
-    embedding_info = self._get_embedding_info_for_tensor(tensor_name, config)
+    embedding_info = self._get_embedding(tensor_name, config)
     if embedding_info:
       return embedding_info.metadata_path
     return None
 
   def _get_bookmarks_file_for_tensor(self, tensor_name, config):
-    embedding_info = self._get_embedding_info_for_tensor(tensor_name, config)
+    embedding_info = self._get_embedding(tensor_name, config)
     if embedding_info:
       return embedding_info.bookmarks_path
     return None
@@ -133,7 +181,7 @@ class ProjectorPlugin(TBPlugin):
     else:
       return tensor_name
 
-  def _get_embedding_info_for_tensor(self, tensor_name, config):
+  def _get_embedding(self, tensor_name, config):
     if not config.embeddings:
       return None
     for info in config.embeddings:
@@ -146,7 +194,7 @@ class ProjectorPlugin(TBPlugin):
     """Returns a list of runs that have embeddings."""
     self.handler.respond(list(self.configs.keys()), 'application/json')
 
-  def _serve_info(self, query_params):
+  def _serve_config(self, query_params):
     run = query_params.get('run')
     if run is None:
       self.handler.respond('query parameter "run" is required',
@@ -157,19 +205,6 @@ class ProjectorPlugin(TBPlugin):
       return
 
     config = self.configs[run]
-    reader = self._get_reader_for_run(run)
-    var_map = reader.get_variable_to_shape_map()
-
-    for tensor_name, tensor_shape in var_map.items():
-      if len(tensor_shape) != 2:
-        continue
-      info = self._get_embedding_info_for_tensor(tensor_name, config)
-      if not info:
-        info = config.embeddings.add()
-        info.tensor_name = tensor_name
-      if not info.tensor_shape:
-        info.tensor_shape.extend(tensor_shape)
-
     self.handler.respond(json_format.MessageToJson(config), 'application/json')
 
   def _serve_metadata(self, query_params):
@@ -192,7 +227,7 @@ class ProjectorPlugin(TBPlugin):
     fpath = self._get_metadata_file_for_tensor(name, config)
     if not fpath:
       self.handler.respond(
-          'Not metadata file found for tensor %s in the config file %s' %
+          'No metadata file found for tensor %s in the config file %s' %
          (name, self.config_fpaths[run]), 'text/plain', 400)
       return
     if not file_io.file_exists(fpath) or file_io.is_directory(fpath):
@@ -231,12 +266,27 @@ class ProjectorPlugin(TBPlugin):
 
     reader = self._get_reader_for_run(run)
     config = self.configs[run]
-    if not reader.has_tensor(name):
-      self.handler.respond('Tensor %s not found in checkpoint dir %s' %
-                           (name, config.model_checkpoint_path),
-                           'text/plain', 400)
-      return
-    tensor = reader.get_tensor(name)
+
+    if reader is None:
+      # See if there is a tensor file in the config.
+      embedding = self._get_embedding(name, config)
+      if not embedding or not embedding.tensor_path:
+        self.handler.respond('Tensor %s has no tensor_path in the config' %
+                             name, 'text/plain', 400)
+        return
+      if not file_io.file_exists(embedding.tensor_path):
+        self.handler.respond('Tensor file %s does not exist' %
+                             embedding.tensor_path, 'text/plain', 400)
+        return
+      tensor = _read_tensor_file(embedding.tensor_path)
+    else:
+      if not reader.has_tensor(name):
+        self.handler.respond('Tensor %s not found in checkpoint dir %s' %
+                             (name, config.model_checkpoint_path),
+                             'text/plain', 400)
+        return
+      tensor = reader.get_tensor(name)
+
     # Sample the tensor
     tensor = tensor[:LIMIT_NUM_POINTS]
     # Stream it as TSV.
@@ -294,7 +344,7 @@ class ProjectorPlugin(TBPlugin):
       return
 
     config = self.configs[run]
-    embedding_info = self._get_embedding_info_for_tensor(name, config)
+    embedding_info = self._get_embedding(name, config)
 
     if not embedding_info or not embedding_info.sprite.image_path:
       self.handler.respond(
-- 
GitLab
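
For reference, the following is a minimal end-to-end sketch of the workflow
this patch enables: writing a user-generated TSV tensor file plus a projector
config that points at it, so the projector can serve embeddings with no model
checkpoint present. It is not part of the commit; the directory, filenames,
and vector values are made up for illustration, and it assumes the config
filename matches PROJECTOR_FILENAME (its usual value in this plugin is
'projector_config.pbtxt'). Each line of the tensor file is one embedding
vector with tab-separated dimensions, which is exactly the format that
_read_tensor_file() above parses.

    # Hypothetical usage sketch -- not part of the patch above.
    import os

    logdir = '/tmp/projector_demo'  # made-up log directory
    if not os.path.isdir(logdir):
      os.makedirs(logdir)

    # One embedding vector per line, dimensions separated by tabs.
    tensor_path = os.path.join(logdir, 'tensors.tsv')
    vectors = [[0.1, 0.2, 0.3],
               [0.4, 0.5, 0.6],
               [0.7, 0.8, 0.9]]
    with open(tensor_path, 'w') as f:
      for vec in vectors:
        f.write('\t'.join(str(v) for v in vec) + '\n')

    # Point the projector at the file via the embedding's tensor_path
    # field; the plugin infers tensor_shape from the file itself, so no
    # checkpoint is required.
    with open(os.path.join(logdir, 'projector_config.pbtxt'), 'w') as f:
      f.write('embeddings {\n  tensor_path: "%s"\n}\n' % tensor_path)

With these two files in a run's log directory, _read_config_files() accepts
the run even though no checkpoint exists (via has_tensor_files),
_augment_configs_with_checkpoint_info() fills in the [3, 3] tensor_shape from
the file, and _serve_tensor() streams the vectors back out as TSV.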