提交 c21d3c25 编写于 作者: Z Zhichao Lu 提交者: lzc5123016

Allow model.py to be extended with custom model building functions.

PiperOrigin-RevId: 187941168
上级 307f1f77
......@@ -200,8 +200,8 @@ def create_train_input_fn(train_config, train_input_config,
keypoints for each box.
Raises:
TypeError: if the `train_config` or `train_input_config` are not of the
correct type.
TypeError: if the `train_config`, `train_input_config` or `model_config`
are not of the correct type.
"""
if not isinstance(train_config, train_pb2.TrainConfig):
raise TypeError('For training mode, the `train_config` must be a '
......@@ -316,8 +316,8 @@ def create_eval_input_fn(eval_config, eval_input_config, model_config):
which represent instance masks for objects.
Raises:
TypeError: if the `eval_config` or `eval_input_config` are not of the
correct type.
TypeError: if the `eval_config`, `eval_input_config` or `model_config`
are not of the correct type.
"""
del params
if not isinstance(eval_config, eval_pb2.EvalConfig):
......
......@@ -32,6 +32,7 @@ import tensorflow as tf
from google.protobuf import text_format
from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
from tensorflow.python.lib.io import file_io
from object_detection import eval_util
from object_detection import inputs
from object_detection import model_hparams
......@@ -54,6 +55,20 @@ tf.flags.DEFINE_integer('num_eval_steps', 10000, 'Number of train steps.')
FLAGS = tf.flags.FLAGS
# A map of names to methods that help build the model.
MODEL_BUILD_UTIL_MAP = {
'get_configs_from_pipeline_file':
config_util.get_configs_from_pipeline_file,
'create_pipeline_proto_from_configs':
config_util.create_pipeline_proto_from_configs,
'merge_external_params_with_configs':
config_util.merge_external_params_with_configs,
'create_train_input_fn': inputs.create_train_input_fn,
'create_eval_input_fn': inputs.create_eval_input_fn,
'create_predict_input_fn': inputs.create_predict_input_fn,
}
def _get_groundtruth_data(detection_model, class_agnostic):
"""Extracts groundtruth data from detection_model.
......@@ -413,8 +428,18 @@ def populate_experiment(run_config,
An `Experiment` that defines all aspects of training, evaluation, and
export.
"""
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
configs = config_util.merge_external_params_with_configs(
get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
'get_configs_from_pipeline_file']
create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
'create_pipeline_proto_from_configs']
merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
'merge_external_params_with_configs']
create_train_input_fn = MODEL_BUILD_UTIL_MAP['create_train_input_fn']
create_eval_input_fn = MODEL_BUILD_UTIL_MAP['create_eval_input_fn']
create_predict_input_fn = MODEL_BUILD_UTIL_MAP['create_predict_input_fn']
configs = get_configs_from_pipeline_file(pipeline_config_path)
configs = merge_external_params_with_configs(
configs,
hparams,
train_steps=train_steps,
......@@ -436,18 +461,18 @@ def populate_experiment(run_config,
model_builder.build, model_config=model_config)
# Create the input functions for TRAIN/EVAL.
train_input_fn = inputs.create_train_input_fn(
train_input_fn = create_train_input_fn(
train_config=train_config,
train_input_config=train_input_config,
model_config=model_config)
eval_input_fn = inputs.create_eval_input_fn(
eval_input_fn = create_eval_input_fn(
eval_config=eval_config,
eval_input_config=eval_input_config,
model_config=model_config)
export_strategies = [
tf.contrib.learn.utils.saved_model_export_utils.make_export_strategy(
serving_input_fn=inputs.create_predict_input_fn(
serving_input_fn=create_predict_input_fn(
model_config=model_config))
]
......@@ -457,8 +482,10 @@ def populate_experiment(run_config,
if run_config.is_chief:
# Store the final pipeline config for traceability.
pipeline_config_final = config_util.create_pipeline_proto_from_configs(
pipeline_config_final = create_pipeline_proto_from_configs(
configs)
if not file_io.file_exists(estimator.model_dir):
file_io.recursive_create_dir(estimator.model_dir)
pipeline_config_final_path = os.path.join(estimator.model_dir,
'pipeline.config')
config_text = text_format.MessageToString(pipeline_config_final)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册