提交 41d6b889 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Fix the constructors of DNNClassifier and DNNREgressor.

Change: 124003810
上级 e1d23bbb
......@@ -94,14 +94,17 @@ class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
weight_column_name=None,
optimizer=None,
activation_fn=nn.relu,
dropout=None):
super(DNNClassifier, self).__init__(n_classes=n_classes,
dropout=None,
config=None):
super(DNNClassifier, self).__init__(model_dir=model_dir,
n_classes=n_classes,
weight_column_name=weight_column_name,
dnn_feature_columns=feature_columns,
dnn_optimizer=optimizer,
dnn_hidden_units=hidden_units,
dnn_activation_fn=activation_fn,
dnn_dropout=dropout)
dnn_dropout=dropout,
config=config)
def _get_train_ops(self, features, targets):
"""See base class."""
......@@ -185,13 +188,16 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
weight_column_name=None,
optimizer=None,
activation_fn=nn.relu,
dropout=None):
super(DNNRegressor, self).__init__(weight_column_name=weight_column_name,
dropout=None,
config=None):
super(DNNRegressor, self).__init__(model_dir=model_dir,
weight_column_name=weight_column_name,
dnn_feature_columns=feature_columns,
dnn_optimizer=optimizer,
dnn_hidden_units=hidden_units,
dnn_activation_fn=activation_fn,
dnn_dropout=dropout)
dnn_dropout=dropout,
config=config)
def _get_train_ops(self, features, targets):
"""See base class."""
......
......@@ -80,13 +80,15 @@ class LinearClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
model_dir=None,
n_classes=2,
weight_column_name=None,
optimizer=None):
optimizer=None,
config=None):
super(LinearClassifier, self).__init__(
model_dir=model_dir,
n_classes=n_classes,
weight_column_name=weight_column_name,
linear_feature_columns=feature_columns,
linear_optimizer=optimizer)
linear_optimizer=optimizer,
config=config)
def _get_train_ops(self, features, targets):
"""See base class."""
......@@ -156,12 +158,14 @@ class LinearRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
model_dir=None,
n_classes=2,
weight_column_name=None,
optimizer=None):
optimizer=None,
config=None):
super(LinearRegressor, self).__init__(
model_dir=model_dir,
weight_column_name=weight_column_name,
linear_feature_columns=feature_columns,
linear_optimizer=optimizer)
linear_optimizer=optimizer,
config=config)
def _get_train_ops(self, features, targets):
"""See base class."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册