diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index e001d866c3550af109eb41e7f0e625d10319e296..193d14e1ce89c31d78370d944f76d85364bbeebb 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -360,6 +360,12 @@ def _call_model_fn(model_fn, features, labels, mode, config, params, """Calls the model_fn with required parameters.""" model_fn_args = util.fn_args(model_fn) kwargs = {} + if 'labels' in model_fn_args: + kwargs['labels'] = labels + else: + if labels is not None: + raise ValueError( + 'model_fn does not take labels, but input_fn returns labels.') if 'mode' in model_fn_args: kwargs['mode'] = mode if 'config' in model_fn_args: @@ -371,7 +377,7 @@ def _call_model_fn(model_fn, features, labels, mode, config, params, 'model_fn ({}) does not include params argument, ' 'required by TPUEstimator to pass batch size as ' 'params[\'batch_size\']'.format(model_fn)) - return model_fn(features=features, labels=labels, **kwargs) + return model_fn(features=features, **kwargs) def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):