提交 d51eb956 编写于 作者: I Illia Polosukhin 提交者: TensorFlower Gardener

Support for passing params into Estimator that will be passed to the function...

Support for passing params into Estimator that will be passed to the function call of model_fn. Note, params are only passed if model_fn takes 4 arguments, otherwise it works the same way.
Change: 123985980
上级 dfb532cb
......@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import abc
import inspect
import os
import tempfile
import time
......@@ -74,6 +75,19 @@ def _get_predict_input_fn(x, y, batch_size):
return df.input_builder, df.get_feed_dict_fn()
def _get_arguments(func):
"""Returns list of arguments this function has."""
if hasattr(func, '__code__'):
# Regular function.
return inspect.getargspec(func).args
elif hasattr(func, '__call__'):
# Callable object.
return _get_arguments(func.__call__)
elif hasattr(func, 'func'):
# Partial function.
return _get_arguments(func.func)
class BaseEstimator(sklearn.BaseEstimator):
"""Abstract BaseEstimator class to train and evaluate TensorFlow models.
......@@ -589,17 +603,58 @@ class Estimator(BaseEstimator):
Parameters:
model_fn: Model function, takes features and targets tensors or dicts of
tensors and returns predictions and loss tensors.
E.g. `(features, targets) -> (predictions, loss, train_op)`.
Supports next three signatures for the function:
* `(features, targets) -> (predictions, loss, train_op)`
* `(features, targets, mode) -> (predictions, loss, train_op)`
* `(features, targets, mode, params) ->
(predictions, loss, train_op)`
Where:
* `features` are single `Tensor` or `dict` of `Tensor`s
(depending on data passed to `fit`),
* `targets` are `Tensor` or
`dict` of `Tensor`s (for multi-head model).
* `mode` represents if this training, evaluation or prediction.
See `ModeKeys` for example keys.
* `params` is a `dict` of hyperparameters. Will receive what is
passed to Estimator in `params` parameter. This allows to
configure Estimators from hyper parameter tunning.
model_dir: Directory to save model parameters, graph and etc.
config: Configuration object.
params: `dict` of hyper parameters that will be passed into `model_fn`.
Keys are names of parameters, values are basic python types.
"""
def __init__(self,
model_fn=None,
model_dir=None,
config=None):
config=None,
params=None):
super(Estimator, self).__init__(model_dir=model_dir, config=config)
if model_fn is not None:
# Check number of arguments of the given function matches requirements.
model_fn_args = _get_arguments(model_fn)
if params is not None and 'params' not in model_fn_args:
raise ValueError('Estimator\'s model_fn (%s) has less then 4 '
'arguments, but not None params (%s) are passed.' %
(model_fn, params))
if params is None and 'params' in model_fn_args:
logging.warning('Estimator\'s model_fn (%s) has includes params '
'argument, but params are not passed to Estimator.' %
model_fn)
self._model_fn = model_fn
self.params = params
def _call_model_fn(self, features, targets, mode):
"""Calls model function with support of 2, 3 or 4 arguments."""
model_fn_args = _get_arguments(self._model_fn)
if 'mode' in model_fn_args:
if 'params' in model_fn_args:
return self._model_fn(
features, targets, mode=mode, params=self.params)
else:
return self._model_fn(
features, targets, mode=mode)
return self._model_fn(features, targets)
def _get_train_ops(self, features, targets):
"""Method that builds model graph and returns trainer ops.
......@@ -615,7 +670,7 @@ class Estimator(BaseEstimator):
Returns:
Tuple of train `Operation` and loss `Tensor`.
"""
_, loss, train_op = self._model_fn(features, targets, ModeKeys.TRAIN)
_, loss, train_op = self._call_model_fn(features, targets, ModeKeys.TRAIN)
return train_op, loss
def _get_eval_ops(self, features, targets, metrics):
......@@ -633,7 +688,7 @@ class Estimator(BaseEstimator):
Returns:
metrics: `dict` of `Tensor` objects.
"""
predictions, loss, _ = self._model_fn(features, targets, ModeKeys.EVAL)
predictions, loss, _ = self._call_model_fn(features, targets, ModeKeys.EVAL)
result = {'loss': loss}
metrics = metrics or {}
if isinstance(targets, dict) and len(targets) == 1:
......@@ -663,5 +718,5 @@ class Estimator(BaseEstimator):
"""
targets = tensor_signature.create_placeholders_from_signatures(
self._targets_info)
predictions, _, _ = self._model_fn(features, targets, ModeKeys.INFER)
predictions, _, _ = self._call_model_fn(features, targets, ModeKeys.INFER)
return predictions
......@@ -61,7 +61,19 @@ def boston_eval_fn():
return tf.concat(0, [features, features]), tf.concat(0, [target, target])
def linear_model_fn(features, target, unused_mode):
def linear_model_params_fn(features, target, mode, params):
assert mode in ('train', 'eval', 'infer')
prediction, loss = (
tf.contrib.learn.models.linear_regression_zero_init(features, target)
)
train_op = tf.contrib.layers.optimize_loss(
loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad',
learning_rate=params['learning_rate'])
return prediction, loss, train_op
def linear_model_fn(features, target, mode):
assert mode in ('train', 'eval', 'infer')
prediction, loss = (
tf.contrib.learn.models.linear_regression_zero_init(features, target)
)
......@@ -71,7 +83,7 @@ def linear_model_fn(features, target, unused_mode):
return prediction, loss, train_op
def logistic_model_fn(features, target, unused_mode):
def logistic_model_no_mode_fn(features, target):
target = tf.one_hot(target, 3, 1, 0)
prediction, loss = (
tf.contrib.learn.models.logistic_regression_zero_init(features, target)
......@@ -146,6 +158,12 @@ class EstimatorTest(tf.test.TestCase):
metrics={'MSE': tf.contrib.metrics.streaming_mean_squared_error})
self.assertLess(scores3['MSE'], scores['MSE'])
def testEstimatorParams(self):
boston = tf.contrib.learn.datasets.load_boston()
est = tf.contrib.learn.Estimator(model_fn=linear_model_params_fn,
params={'learning_rate': 0.01})
est.fit(x=boston.data, y=boston.target.astype(np.float32), steps=100)
def testBostonAll(self):
boston = tf.contrib.learn.datasets.load_boston()
est = tf.contrib.learn.Estimator(model_fn=linear_model_fn)
......@@ -160,7 +178,7 @@ class EstimatorTest(tf.test.TestCase):
def testIrisAll(self):
iris = tf.contrib.learn.datasets.load_iris()
est = tf.contrib.learn.Estimator(model_fn=logistic_model_fn)
est = tf.contrib.learn.Estimator(model_fn=logistic_model_no_mode_fn)
est.fit(iris.data, iris.target, steps=100)
scores = est.evaluate(
x=iris.data,
......@@ -177,7 +195,7 @@ class EstimatorTest(tf.test.TestCase):
def testIrisInputFn(self):
iris = tf.contrib.learn.datasets.load_iris()
est = tf.contrib.learn.Estimator(model_fn=logistic_model_fn)
est = tf.contrib.learn.Estimator(model_fn=logistic_model_no_mode_fn)
est.fit(input_fn=iris_input_fn, steps=100)
_ = est.evaluate(input_fn=iris_input_fn, steps=1)
predictions = est.predict(x=iris.data)['class']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册