提交 53604916 编写于 作者: J Jianwei Xie 提交者: TensorFlower Gardener

Fixed the missing labels test in TPUEstimator.

PiperOrigin-RevId: 161131282
上级 7d5c74a9
......@@ -360,6 +360,12 @@ def _call_model_fn(model_fn, features, labels, mode, config, params,
"""Calls the model_fn with required parameters."""
model_fn_args = util.fn_args(model_fn)
kwargs = {}
if 'labels' in model_fn_args:
kwargs['labels'] = labels
else:
if labels is not None:
raise ValueError(
'model_fn does not take labels, but input_fn returns labels.')
if 'mode' in model_fn_args:
kwargs['mode'] = mode
if 'config' in model_fn_args:
......@@ -371,7 +377,7 @@ def _call_model_fn(model_fn, features, labels, mode, config, params,
'model_fn ({}) does not include params argument, '
'required by TPUEstimator to pass batch size as '
'params[\'batch_size\']'.format(model_fn))
return model_fn(features=features, labels=labels, **kwargs)
return model_fn(features=features, **kwargs)
def _call_model_fn_with_tpu(model_fn, features, labels, mode, config, params):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册