Commit 35a4183c authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Add an arg to fix the global_step increment bug for DNNLinearCombined estimators.

Change: 149956102
Parent 7160a7e0
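For context, a minimal sketch of how a caller opts into the fix, assuming the contrib.learn APIs as of this commit; the feature columns, dimensions, and hidden units are illustrative placeholders:

```python
# Hedged sketch: enabling the corrected global_step accounting.
# Columns, dimensions, and hidden units below are illustrative.
from tensorflow.contrib import layers
from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined

language = layers.sparse_column_with_hash_bucket('language', 10)
age = layers.real_valued_column('age')

classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
    linear_feature_columns=[age, language],
    dnn_feature_columns=[layers.embedding_column(language, dimension=1)],
    dnn_hidden_units=[3, 3],
    # Without this flag, global_step advances once per internal
    # optimize_loss call, i.e. twice per step for a combined model.
    fix_global_step_increment_bug=True)
```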
@@ -66,7 +66,8 @@ def optimize_loss(loss,
variables=None,
name=None,
summaries=None,
colocate_gradients_with_ops=False):
colocate_gradients_with_ops=False,
increment_global_step=True):
"""Given loss and parameters for optimizer, returns a training op.
Various ways of passing optimizers include:
@@ -87,9 +88,10 @@ def optimize_loss(loss,
Args:
loss: Scalar `Tensor`.
global_step: Scalar int `Tensor`, step counter for each update. If not
supplied, it will be fetched from the default graph (see
`tf.contrib.framework.get_global_step` for details). If it's
global_step: Scalar int `Tensor`, step counter to update on each step
unless `increment_global_step` is `False`. If not supplied,
it will be fetched from the default graph (see
`tf.train.get_global_step` for details). If it's
not been created, no step will be incremented with each weight
update. `learning_rate_decay_fn` requires `global_step`.
learning_rate: float or `Tensor`, magnitude of update per each training
@@ -129,6 +131,10 @@ def optimize_loss(loss,
complete list is in OPTIMIZER_SUMMARIES.
colocate_gradients_with_ops: If True, try colocating gradients with the
corresponding op.
increment_global_step: Whether to increment `global_step`. If your model
calls `optimize_loss` multiple times per training step (e.g. to optimize
different parts of the model), use this arg to avoid incrementing
`global_step` more times than necessary.
Returns:
Training op.
@@ -277,7 +283,9 @@ def optimize_loss(loss,
# Create gradient updates.
grad_updates = opt.apply_gradients(
gradients, global_step=global_step, name="train")
gradients,
global_step=global_step if increment_global_step else None,
name="train")
# Ensure the train_tensor computes grad_updates.
train_tensor = control_flow_ops.with_dependencies([grad_updates], loss)
......
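To see the new `optimize_loss` argument in isolation, here is a minimal sketch, assuming TF 1.x and the contrib `optimizers` module changed above; the tiny two-part model is illustrative:

```python
# Hedged sketch: two optimize_loss calls per training step, where only
# the first call advances global_step.
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import optimizers as optimizers_lib
from tensorflow.python.training import training_util

x = tf.placeholder(tf.float32, [])
var_a = tf.get_variable('var_a', [], initializer=tf.constant_initializer(1.0))
var_b = tf.get_variable('var_b', [], initializer=tf.constant_initializer(1.0))
global_step = training_util.create_global_step()

train_a = optimizers_lib.optimize_loss(
    loss=tf.square(var_a * x),
    global_step=global_step,
    learning_rate=0.1,
    optimizer='SGD',
    variables=[var_a],
    name='part_a')  # default increment_global_step=True: counter advances
train_b = optimizers_lib.optimize_loss(
    loss=tf.square(var_b * x),
    global_step=global_step,
    learning_rate=0.1,
    optimizer='SGD',
    variables=[var_b],
    name='part_b',
    increment_global_step=False)  # second call leaves the counter alone

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run([train_a, train_b], feed_dict={x: 5.0})
  print(sess.run(global_step))  # 1, not 2
```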
@@ -358,6 +358,30 @@ class OptimizersTest(test.TestCase):
self.assertEqual(20, update_var.eval())
self.assertEqual(1, global_step.eval())
def testUpdateOpNoIncrementGlobalStep(self):
optimizers = [
"SGD", gradient_descent.GradientDescentOptimizer,
gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
]
for optimizer in optimizers:
with ops.Graph().as_default() as g, self.test_session(graph=g) as session:
x, var, loss, global_step = _setup_model()
update_var = variable_scope.get_variable(
"update", [], initializer=init_ops.constant_initializer(10))
update_op = state_ops.assign(update_var, 20)
train = optimizers_lib.optimize_loss(
loss,
global_step,
learning_rate=0.1,
optimizer=optimizer,
update_ops=[update_op],
increment_global_step=False)
variables.global_variables_initializer().run()
session.run(train, feed_dict={x: 5})
self.assertEqual(9.5, var.eval())
self.assertEqual(20, update_var.eval())
self.assertEqual(0, global_step.eval())
def testUpdateOpWithNoOpDecay(self):
optimizers = [
"SGD", gradient_descent.GradientDescentOptimizer,
......
@@ -25,7 +25,6 @@ import six
from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn import metric_spec
@@ -39,7 +38,9 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import training_util
# The default learning rates are a historical artifact of the initial
@@ -48,6 +49,12 @@ _DNN_LEARNING_RATE = 0.05
_LINEAR_LEARNING_RATE = 0.2
_FIX_GLOBAL_STEP_INCREMENT_DATE = "2017-04-15"
_FIX_GLOBAL_STEP_INCREMENT_INSTRUCTIONS = (
"Please set fix_global_step_increment_bug=True and update training steps "
"in your pipeline. See pydoc for details.")
def _as_iterable(preds, output):
for pred in preds:
yield pred[output]
@@ -174,6 +181,8 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
params.get("input_layer_min_slice_size") or 64 << 20)
num_ps_replicas = config.num_ps_replicas if config else 0
embedding_lr_multipliers = params.get("embedding_lr_multipliers", {})
fix_global_step_increment_bug = params.get(
"fix_global_step_increment_bug", True)
if not linear_feature_columns and not dnn_feature_columns:
raise ValueError(
@@ -279,11 +288,12 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
def _make_training_op(training_loss):
"""Training op for the DNN linear combined model."""
train_ops = []
global_step = training_util.get_global_step()
if dnn_logits is not None:
train_ops.append(
optimizers.optimize_loss(
loss=training_loss,
global_step=contrib_variables.get_global_step(),
global_step=global_step,
learning_rate=_DNN_LEARNING_RATE,
optimizer=_get_optimizer(dnn_optimizer),
gradient_multipliers=_extract_embedding_lr_multipliers( # pylint: disable=protected-access
@@ -293,21 +303,28 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
variables=ops.get_collection(dnn_parent_scope),
name=dnn_parent_scope,
# Empty summaries, because head already logs "loss" summary.
summaries=[]))
summaries=[],
increment_global_step=not fix_global_step_increment_bug))
if linear_logits is not None:
train_ops.append(
optimizers.optimize_loss(
loss=training_loss,
global_step=contrib_variables.get_global_step(),
global_step=global_step,
learning_rate=_linear_learning_rate(len(linear_feature_columns)),
optimizer=_get_optimizer(linear_optimizer),
clip_gradients=gradient_clip_norm,
variables=ops.get_collection(linear_parent_scope),
name=linear_parent_scope,
# Empty summaries, because head already logs "loss" summary.
summaries=[]))
summaries=[],
increment_global_step=not fix_global_step_increment_bug))
return control_flow_ops.group(*train_ops)
train_op = control_flow_ops.group(*train_ops)
if fix_global_step_increment_bug:
with ops.control_dependencies([train_op]):
with ops.colocate_with(global_step):
return state_ops.assign_add(global_step, 1).op
return train_op
return head.create_model_fn_ops(
features=features,
@@ -320,20 +337,27 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
class DNNLinearCombinedEstimator(estimator.Estimator):
"""An estimator for TensorFlow Linear and DNN joined training models.
Input of `fit`, `train`, and `evaluate` should have the following features,
otherwise there will be a `KeyError`:
if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
for each `column` in `dnn_feature_columns` + `linear_feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `WeightedSparseColumn`, two features: the first with
`key` the id column name, the second with `key` the weight column
name. Both features' `value` must be a `SparseTensor`.
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
whose `value` is a `Tensor`.
Note: New users must set `fix_global_step_increment_bug=True` when creating an
estimator.
Input of `fit`, `train`, and `evaluate` should have the following features,
otherwise there will be a `KeyError`:
if `weight_column_name` is not `None`, a feature with
`key=weight_column_name` whose value is a `Tensor`.
for each `column` in `dnn_feature_columns` + `linear_feature_columns`:
- if `column` is a `SparseColumn`, a feature with `key=column.name`
whose `value` is a `SparseTensor`.
- if `column` is a `WeightedSparseColumn`, two features: the first with
`key` the id column name, the second with `key` the weight column
name. Both features' `value` must be a `SparseTensor`.
- if `column` is a `RealValuedColumn`, a feature with `key=column.name`
whose `value` is a `Tensor`.
"""
@deprecated_arg_values(
_FIX_GLOBAL_STEP_INCREMENT_DATE,
_FIX_GLOBAL_STEP_INCREMENT_INSTRUCTIONS,
fix_global_step_increment_bug=False)
def __init__(self, # _joint_linear_weights pylint: disable=invalid-name
head,
model_dir=None,
@@ -348,9 +372,13 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
gradient_clip_norm=None,
config=None,
feature_engineering_fn=None,
embedding_lr_multipliers=None):
embedding_lr_multipliers=None,
fix_global_step_increment_bug=False):
"""Initializes a DNNLinearCombinedEstimator instance.
Note: New users must set `fix_global_step_increment_bug=True` when creating
an estimator.
Args:
head: A _Head object.
model_dir: Directory to save model parameters, graph, etc. This can
@@ -381,12 +409,15 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
tf.clip_by_global_norm for more details.
config: RunConfig object to configure the runtime settings.
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
embedding_lr_multipliers: Optional. A dictionary from `EmbeddingColumn` to
a `float` multiplier. Multiplier will be used to multiply with
learning rate for the embedding variables.
fix_global_step_increment_bug: If `False`, the estimator needs two fit
steps to optimize both linear and dnn parts. If `True`, this bug is
fixed. New users must set this to `True`, but the default value is
`False` for backwards compatibility.
Raises:
ValueError: If both linear_feature_columns and dnn_features_columns are
@@ -413,6 +444,7 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
"dnn_dropout": dnn_dropout,
"gradient_clip_norm": gradient_clip_norm,
"embedding_lr_multipliers": embedding_lr_multipliers,
"fix_global_step_increment_bug": fix_global_step_increment_bug,
},
feature_engineering_fn=feature_engineering_fn)
@@ -420,6 +452,9 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
class DNNLinearCombinedClassifier(estimator.Estimator):
"""A classifier for TensorFlow Linear and DNN joined training models.
Note: New users must set `fix_global_step_increment_bug=True` when creating an
estimator.
Example:
```python
@@ -469,6 +504,10 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
whose `value` is a `Tensor`.
"""
@deprecated_arg_values(
_FIX_GLOBAL_STEP_INCREMENT_DATE,
_FIX_GLOBAL_STEP_INCREMENT_INSTRUCTIONS,
fix_global_step_increment_bug=False)
def __init__(self, # _joint_linear_weights pylint: disable=invalid-name
model_dir=None,
n_classes=2,
@@ -486,9 +525,13 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
config=None,
feature_engineering_fn=None,
embedding_lr_multipliers=None,
input_layer_min_slice_size=None):
input_layer_min_slice_size=None,
fix_global_step_increment_bug=False):
"""Constructs a DNNLinearCombinedClassifier instance.
Note: New users must set `fix_global_step_increment_bug=True` when creating
an estimator.
Args:
model_dir: Directory to save model parameters, graph, etc. This can
also be used to load checkpoints from the directory into an estimator
@@ -527,14 +570,17 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
residual after centered bias.
config: RunConfig object to configure the runtime settings.
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
embedding_lr_multipliers: Optional. A dictionary from `EmbeddingColumn` to
a `float` multiplier. Multiplier will be used to multiply with
learning rate for the embedding variables.
input_layer_min_slice_size: Optional. The min slice size of input layer
partitions. If not provided, will use the default of 64M.
fix_global_step_increment_bug: If `False`, the estimator needs two fit
steps to optimize both linear and dnn parts. If `True`, this bug is
fixed. New users must set this to `True`, but the default value is
`False` for backwards compatibility.
Raises:
ValueError: If `n_classes` < 2.
@@ -571,6 +617,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
"gradient_clip_norm": gradient_clip_norm,
"embedding_lr_multipliers": embedding_lr_multipliers,
"input_layer_min_slice_size": input_layer_min_slice_size,
"fix_global_step_increment_bug": fix_global_step_increment_bug,
},
feature_engineering_fn=feature_engineering_fn)
@@ -709,6 +756,9 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
class DNNLinearCombinedRegressor(estimator.Estimator):
"""A regressor for TensorFlow Linear and DNN joined training models.
Note: New users must set `fix_global_step_increment_bug=True` when creating an
estimator.
Example:
```python
@@ -764,6 +814,10 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
whose `value` is a `Tensor`.
"""
@deprecated_arg_values(
_FIX_GLOBAL_STEP_INCREMENT_DATE,
_FIX_GLOBAL_STEP_INCREMENT_INSTRUCTIONS,
fix_global_step_increment_bug=False)
def __init__(self, # _joint_linear_weights pylint: disable=invalid-name
model_dir=None,
weight_column_name=None,
@@ -781,9 +835,13 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
config=None,
feature_engineering_fn=None,
embedding_lr_multipliers=None,
input_layer_min_slice_size=None):
input_layer_min_slice_size=None,
fix_global_step_increment_bug=False):
"""Initializes a DNNLinearCombinedRegressor instance.
Note: New users must set `fix_global_step_increment_bug=True` when creating
an estimator.
Args:
model_dir: Directory to save model parameters, graph, etc. This can
also be used to load checkpoints from the directory into an estimator
@@ -821,15 +879,17 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
(typically, these have shape `[batch_size, label_dimension]`).
config: RunConfig object to configure the runtime settings.
feature_engineering_fn: Feature engineering function. Takes features and
labels which are the output of `input_fn` and returns features and
labels which will be fed into the model.
embedding_lr_multipliers: Optional. A dictionary from `EmbeddingColumn` to
a `float` multiplier. Multiplier will be used to multiply with
learning rate for the embedding variables.
input_layer_min_slice_size: Optional. The min slice size of input layer
partitions. If not provided, will use the default of 64M.
fix_global_step_increment_bug: If `False`, the estimator needs two fit
steps to optimize both linear and dnn parts. If `True`, this bug is
fixed. New users must set this to `True`, but the default value is
`False` for backwards compatibility.
Raises:
ValueError: If both linear_feature_columns and dnn_features_columns are
@@ -862,6 +922,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
"gradient_clip_norm": gradient_clip_norm,
"embedding_lr_multipliers": embedding_lr_multipliers,
"input_layer_min_slice_size": input_layer_min_slice_size,
"fix_global_step_increment_bug": fix_global_step_increment_bug,
},
feature_engineering_fn=feature_engineering_fn)
......
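The model-fn change above reduces to one pattern: run each sub-optimizer with `increment_global_step=False`, then advance the counter exactly once after all updates complete. A standalone sketch of that pattern, assuming `dnn_train_op`, `linear_train_op`, and `global_step` already exist:

```python
# Hedged sketch of the single-increment pattern from _make_training_op.
# dnn_train_op and linear_train_op are assumed to be sub-train-ops built
# with increment_global_step=False.
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import state_ops

train_op = control_flow_ops.group(dnn_train_op, linear_train_op)
with ops.control_dependencies([train_op]):
  with ops.colocate_with(global_step):
    # Exactly one increment per training step, regardless of how many
    # optimizers contributed updates.
    train_op = state_ops.assign_add(global_step, 1).op
```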
@@ -24,7 +24,6 @@ import tempfile
import numpy as np
from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.contrib.learn.python.learn.datasets import base
@@ -39,6 +38,7 @@ from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec
from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@@ -51,6 +51,8 @@ from tensorflow.python.training import input as input_lib
from tensorflow.python.training import learning_rate_decay
from tensorflow.python.training import monitored_session
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
def _assert_metrics_in_range(keys, metrics):
@@ -88,6 +90,21 @@ class _CheckCallsHead(head_lib.Head):
return self._head_ops_called_times
class _StepCounterHook(session_run_hook.SessionRunHook):
"""Counts the number of training steps."""
def __init__(self):
self._steps = 0
def after_run(self, run_context, run_values):
del run_context, run_values
self._steps += 1
@property
def steps(self):
return self._steps
class EmbeddingMultiplierTest(test.TestCase):
"""dnn_model_fn tests."""
@@ -133,38 +150,39 @@ class EmbeddingMultiplierTest(test.TestCase):
'dnn_feature_columns': [embedding_language, embedding_wire],
'head': head_lib.multi_class_head(2),
'dnn_hidden_units': [1],
# Set lr mult to 0. to keep embeddings constant.
# Set lr mult to 0. to keep language embeddings constant, whereas wire
# embeddings will be trained.
'embedding_lr_multipliers': {
embedding_language: 0.0
},
'dnn_optimizer': 'Adagrad',
}
features = {
'language':
sparse_tensor.SparseTensor(
values=['en', 'fr', 'zh'],
indices=[[0, 0], [1, 0], [2, 0]],
dense_shape=[3, 1]),
'wire':
sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'],
indices=[[0, 0], [1, 0], [2, 0]],
dense_shape=[3, 1]),
}
labels = constant_op.constant([[0], [0], [0]], dtype=dtypes.int32)
model_ops = dnn_linear_combined._dnn_linear_combined_model_fn(
features, labels, model_fn.ModeKeys.TRAIN, params)
with monitored_session.MonitoredSession() as sess:
language_var = dnn_linear_combined._get_embedding_variable(
embedding_language, 'dnn', 'dnn/input_from_feature_columns')
wire_var = dnn_linear_combined._get_embedding_variable(
embedding_wire, 'dnn', 'dnn/input_from_feature_columns')
for _ in range(2):
_, language_value, wire_value = sess.run(
[model_ops.train_op, language_var, wire_var])
initial_value = np.full_like(language_value, 0.1)
self.assertTrue(np.all(np.isclose(language_value, initial_value)))
self.assertFalse(np.all(np.isclose(wire_value, initial_value)))
with ops.Graph().as_default():
features = {
'language':
sparse_tensor.SparseTensor(
values=['en', 'fr', 'zh'],
indices=[[0, 0], [1, 0], [2, 0]],
dense_shape=[3, 1]),
'wire':
sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'],
indices=[[0, 0], [1, 0], [2, 0]],
dense_shape=[3, 1]),
}
labels = constant_op.constant([[1], [0], [0]], dtype=dtypes.int32)
training_util.create_global_step()
model_ops = dnn_linear_combined._dnn_linear_combined_model_fn(
features, labels, model_fn.ModeKeys.TRAIN, params)
with monitored_session.MonitoredSession() as sess:
language_var = dnn_linear_combined._get_embedding_variable(
embedding_language, 'dnn', 'dnn/input_from_feature_columns')
language_initial_value = sess.run(language_var)
for _ in range(2):
_, language_value = sess.run([model_ops.train_op, language_var])
self.assertAllClose(language_value, language_initial_value)
# We could also test that wire_value changed, but that test would be flaky.
class DNNLinearCombinedEstimatorTest(test.TestCase):
@@ -601,7 +619,7 @@ class DNNLinearCombinedClassifierTest(test.TestCase):
]
def _optimizer_exp_decay():
global_step = variables.get_global_step()
global_step = training_util.get_global_step()
learning_rate = learning_rate_decay.exponential_decay(
learning_rate=0.1,
global_step=global_step,
@@ -736,7 +754,7 @@ class DNNLinearCombinedClassifierTest(test.TestCase):
})
def testVariableQuery(self):
"""Tests bias is centered or not."""
"""Tests get_variable_names and get_variable_value."""
def _input_fn_train():
# Create 4 rows, three (y = x), one (y=Not(x))
@@ -838,6 +856,99 @@ class DNNLinearCombinedClassifierTest(test.TestCase):
classifier.fit(input_fn=_input_fn_train, steps=500)
self.assertNotIn('centered_bias_weight', classifier.get_variable_names())
def testGlobalStepLinearOnly(self):
"""Tests global step update for linear-only model."""
def input_fn():
return {
'age': constant_op.constant([1]),
'language':
sparse_tensor.SparseTensor(
values=['english'], indices=[[0, 0]], dense_shape=[1, 1])
}, constant_op.constant([[1]])
language = feature_column.sparse_column_with_hash_bucket('language', 10)
age = feature_column.real_valued_column('age')
step_counter = _StepCounterHook()
classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
linear_feature_columns=[age, language])
classifier.fit(input_fn=input_fn, steps=100, monitors=[step_counter])
self.assertEqual(100, step_counter.steps)
def testGlobalStepDNNOnly(self):
"""Tests global step update for dnn-only model."""
def input_fn():
return {
'language':
sparse_tensor.SparseTensor(
values=['english'], indices=[[0, 0]], dense_shape=[1, 1])
}, constant_op.constant([[1]])
language = feature_column.sparse_column_with_hash_bucket('language', 10)
step_counter = _StepCounterHook()
classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
dnn_feature_columns=[
feature_column.embedding_column(language, dimension=1)],
dnn_hidden_units=[3, 3])
classifier.fit(input_fn=input_fn, steps=100, monitors=[step_counter])
self.assertEqual(100, step_counter.steps)
def testGlobalStepDNNLinearCombinedBug(self):
"""Tests global step update for dnn-linear combined model."""
def input_fn():
return {
'age': constant_op.constant([1]),
'language':
sparse_tensor.SparseTensor(
values=['english'], indices=[[0, 0]], dense_shape=[1, 1])
}, constant_op.constant([[1]])
language = feature_column.sparse_column_with_hash_bucket('language', 10)
age = feature_column.real_valued_column('age')
step_counter = _StepCounterHook()
classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
linear_feature_columns=[age, language],
dnn_feature_columns=[
feature_column.embedding_column(language, dimension=1)],
dnn_hidden_units=[3, 3],
fix_global_step_increment_bug=False)
classifier.fit(input_fn=input_fn, steps=100, monitors=[step_counter])
# Expected is 100, but because of the global step increment bug, this is 51.
self.assertEqual(51, step_counter.steps)
def testGlobalStepDNNLinearCombinedBugFixed(self):
"""Tests global step update for dnn-linear combined model."""
def input_fn():
return {
'age': constant_op.constant([1]),
'language':
sparse_tensor.SparseTensor(
values=['english'], indices=[[0, 0]], dense_shape=[1, 1])
}, constant_op.constant([[1]])
language = feature_column.sparse_column_with_hash_bucket('language', 10)
age = feature_column.real_valued_column('age')
step_counter = _StepCounterHook()
classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
linear_feature_columns=[age, language],
dnn_feature_columns=[
feature_column.embedding_column(language, dimension=1)],
dnn_hidden_units=[3, 3],
fix_global_step_increment_bug=True)
classifier.fit(input_fn=input_fn, steps=100, monitors=[step_counter])
self.assertEqual(100, step_counter.steps)
def testLinearOnly(self):
"""Tests that linear-only instantiation works."""
......
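A practical consequence, and the reason the deprecation message says to update training steps in your pipeline: with the bug, `fit(steps=100)` on a combined model stopped after roughly 51 real iterations because `global_step` advanced twice per step, while the fixed estimator runs all 100. A sketch mirroring `testGlobalStepDNNLinearCombinedBugFixed`, with illustrative data:

```python
# Hedged sketch: step accounting with the fix enabled. Data and model
# shape mirror the tests above and are illustrative.
from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.learn.python.learn.estimators import dnn_linear_combined
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import sparse_tensor

def input_fn():
  return {
      'age': constant_op.constant([1]),
      'language': sparse_tensor.SparseTensor(
          values=['english'], indices=[[0, 0]], dense_shape=[1, 1]),
  }, constant_op.constant([[1]])

language = feature_column.sparse_column_with_hash_bucket('language', 10)
age = feature_column.real_valued_column('age')
classifier = dnn_linear_combined.DNNLinearCombinedClassifier(
    linear_feature_columns=[age, language],
    dnn_feature_columns=[
        feature_column.embedding_column(language, dimension=1)],
    dnn_hidden_units=[3, 3],
    fix_global_step_increment_bug=True)
# Performs 100 optimizer iterations; with fix_global_step_increment_bug=False
# the same call would stop after ~51 (global_step advanced twice per step).
classifier.fit(input_fn=input_fn, steps=100)
```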