diff --git a/research/object_detection/model_lib.py b/research/object_detection/model_lib.py index cd58607f67e1a0c92b2cf6a293f7aa0bfd82d967..525bb34334f351125a8252ad5898f34380d2c67e 100644 --- a/research/object_detection/model_lib.py +++ b/research/object_detection/model_lib.py @@ -22,8 +22,6 @@ import functools import tensorflow as tf -from tensorflow.contrib.tpu.python.tpu import tpu_estimator -from tensorflow.contrib.tpu.python.tpu import tpu_optimizer from object_detection import eval_util from object_detection import inputs from object_detection.builders import model_builder @@ -291,7 +289,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False): if mode == tf.estimator.ModeKeys.TRAIN: if use_tpu: - training_optimizer = tpu_optimizer.CrossShardOptimizer( + training_optimizer = tf.contrib.tpu.CrossShardOptimizer( training_optimizer) # Optionally freeze some layers by setting their gradients to be zero. @@ -490,7 +488,7 @@ def create_estimator_and_inputs(run_config, model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu) if use_tpu_estimator: - estimator = tpu_estimator.TPUEstimator( + estimator = tf.contrib.tpu.TPUEstimator( model_fn=model_fn, train_batch_size=train_config.batch_size, # For each core, only batch size 1 is supported for eval. diff --git a/research/object_detection/protos/train.proto b/research/object_detection/protos/train.proto index e8bf871b9420320ba43aa119df4db4e491c1c7f9..4776b344bdf4e9a3c1a24db0c78a5dc7915bb150 100644 --- a/research/object_detection/protos/train.proto +++ b/research/object_detection/protos/train.proto @@ -7,7 +7,9 @@ import "object_detection/protos/preprocessor.proto"; // Message for configuring DetectionModel training jobs (train.py). message TrainConfig { - // Input queue batch size. + // Effective batch size to use for training. + // For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be + // `batch_size` / number of cores (or `batch_size` / number of GPUs). optional uint32 batch_size = 1 [default=32]; // Data augmentation options.