提交 a9e87682 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Support weighted sparse column in sdca optimizer.

Change: 127093435
上级 cd2162b3
......@@ -203,6 +203,33 @@ class LinearClassifierTest(tf.test.TestCase):
scores = classifier.evaluate(input_fn=input_fn, steps=2)
self.assertGreater(scores['accuracy'], 0.9)
def testSdcaOptimizerWeightedSparseFeatures(self):
"""LinearClasssifier with SDCAOptimizer and weighted sparse features."""
def input_fn():
return {
'example_id': tf.constant(['1', '2', '3']),
'price': tf.SparseTensor(values=[2., 3., 1.],
indices=[[0, 0], [1, 0], [2, 0]],
shape=[3, 5]),
'country': tf.SparseTensor(values=['IT', 'US', 'GB'],
indices=[[0, 0], [1, 0], [2, 0]],
shape=[3, 5])
}, tf.constant([[1], [0], [1]])
country = tf.contrib.layers.sparse_column_with_hash_bucket(
'country', hash_bucket_size=5)
country_weighted_by_price = tf.contrib.layers.weighted_sparse_column(
country, 'price')
sdca_optimizer = tf.contrib.learn.SDCAOptimizer(
example_id_column='example_id')
classifier = tf.contrib.learn.LinearClassifier(
feature_columns=[country_weighted_by_price],
optimizer=sdca_optimizer)
classifier.fit(input_fn=input_fn, steps=50)
scores = classifier.evaluate(input_fn=input_fn, steps=2)
self.assertGreater(scores['accuracy'], 0.9)
def testSdcaOptimizerCrossedFeatures(self):
"""Tests LinearClasssifier with SDCAOptimizer and crossed features."""
......
......@@ -124,6 +124,17 @@ class SDCAOptimizer(object):
column.length)
sparse_features.append(math_ops.to_float(sparse_features_tensor))
sparse_features_weights.append(columns_to_variables[column][0])
elif isinstance(
column,
layers.feature_column._WeightedSparseColumn): # pylint: disable=protected-access
id_tensor = column.id_tensor(transformed_tensor)
weight_tensor = column.weight_tensor(transformed_tensor)
sparse_features_tensor = sparse_ops.sparse_merge(
id_tensor, weight_tensor, column.length,
name="{}_sparse_merge".format(column.name))
sparse_features.append(math_ops.to_float(
sparse_features_tensor, name="{}_to_float".format(column.name)))
sparse_features_weights.append(columns_to_variables[column][0])
else:
raise ValueError("SDCAOptimizer does not support column type %s." %
type(column).__name__)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册