提交 7e810001 编写于 作者: Z Zhichao Lu 提交者: pkulzc

Access TPUEstimator and CrossShardOptimizer from the tf namespace.

PiperOrigin-RevId: 192226678
上级 b0c5c3b5
......@@ -22,8 +22,6 @@ import functools
import tensorflow as tf
from tensorflow.contrib.tpu.python.tpu import tpu_estimator
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
from object_detection import eval_util
from object_detection import inputs
from object_detection.builders import model_builder
......@@ -291,7 +289,7 @@ def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False):
if mode == tf.estimator.ModeKeys.TRAIN:
if use_tpu:
training_optimizer = tpu_optimizer.CrossShardOptimizer(
training_optimizer = tf.contrib.tpu.CrossShardOptimizer(
training_optimizer)
# Optionally freeze some layers by setting their gradients to be zero.
......@@ -490,7 +488,7 @@ def create_estimator_and_inputs(run_config,
model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu)
if use_tpu_estimator:
estimator = tpu_estimator.TPUEstimator(
estimator = tf.contrib.tpu.TPUEstimator(
model_fn=model_fn,
train_batch_size=train_config.batch_size,
# For each core, only batch size 1 is supported for eval.
......
......@@ -7,7 +7,9 @@ import "object_detection/protos/preprocessor.proto";
// Message for configuring DetectionModel training jobs (train.py).
message TrainConfig {
// Input queue batch size.
// Effective batch size to use for training.
// For TPU (or sync SGD jobs), the batch size per core (or GPU) is going to be
// `batch_size` / number of cores (or `batch_size` / number of GPUs).
optional uint32 batch_size = 1 [default=32];
// Data augmentation options.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册