Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
29293fb6
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
29293fb6
编写于
3月 17, 2017
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
3月 17, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support deepcopy in _SparseColumn.
Change: 150488705
上级
2cc1e156
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
78 addition
and
80 deletion
+78
-80
tensorflow/contrib/layers/python/layers/feature_column.py
tensorflow/contrib/layers/python/layers/feature_column.py
+29
-80
tensorflow/contrib/layers/python/layers/feature_column_test.py
...rflow/contrib/layers/python/layers/feature_column_test.py
+49
-0
未找到文件。
tensorflow/contrib/layers/python/layers/feature_column.py
浏览文件 @
29293fb6
...
@@ -329,6 +329,9 @@ class _SparseColumn(_FeatureColumn,
...
@@ -329,6 +329,9 @@ class _SparseColumn(_FeatureColumn,
if
is_integerized
and
not
dtype
.
is_integer
:
if
is_integerized
and
not
dtype
.
is_integer
:
raise
ValueError
(
"dtype must be an integer if is_integerized is True. "
raise
ValueError
(
"dtype must be an integer if is_integerized is True. "
"dtype: {}, column_name: {}."
.
format
(
dtype
,
column_name
))
"dtype: {}, column_name: {}."
.
format
(
dtype
,
column_name
))
if
dtype
!=
dtypes
.
string
and
not
dtype
.
is_integer
:
raise
ValueError
(
"dtype must be string or integer. "
"dtype: {}, column_name: {}"
.
format
(
dtype
,
column_name
))
if
bucket_size
is
None
and
lookup_config
is
None
:
if
bucket_size
is
None
and
lookup_config
is
None
:
raise
ValueError
(
"one of bucket_size or lookup_config must be set. "
raise
ValueError
(
"one of bucket_size or lookup_config must be set. "
...
@@ -355,9 +358,14 @@ class _SparseColumn(_FeatureColumn,
...
@@ -355,9 +358,14 @@ class _SparseColumn(_FeatureColumn,
raise
ValueError
(
"vocab_size must be defined. "
raise
ValueError
(
"vocab_size must be defined. "
"column_name: {}"
.
format
(
column_name
))
"column_name: {}"
.
format
(
column_name
))
return
super
(
_SparseColumn
,
cls
).
__new__
(
cls
,
column_name
,
is_integerized
,
return
super
(
_SparseColumn
,
cls
).
__new__
(
bucket_size
,
lookup_config
,
cls
,
combiner
,
dtype
)
column_name
,
is_integerized
=
is_integerized
,
bucket_size
=
bucket_size
,
lookup_config
=
lookup_config
,
combiner
=
combiner
,
dtype
=
dtype
)
@
property
@
property
def
name
(
self
):
def
name
(
self
):
...
@@ -440,20 +448,6 @@ class _SparseColumn(_FeatureColumn,
...
@@ -440,20 +448,6 @@ class _SparseColumn(_FeatureColumn,
class
_SparseColumnIntegerized
(
_SparseColumn
):
class
_SparseColumnIntegerized
(
_SparseColumn
):
"""See `sparse_column_with_integerized_feature`."""
"""See `sparse_column_with_integerized_feature`."""
def
__new__
(
cls
,
column_name
,
bucket_size
,
combiner
=
"sqrtn"
,
dtype
=
dtypes
.
int64
):
if
not
dtype
.
is_integer
:
raise
ValueError
(
"dtype must be an integer. "
"dtype: {}, column_name: {}"
.
format
(
dtype
,
column_name
))
return
super
(
_SparseColumnIntegerized
,
cls
).
__new__
(
cls
,
column_name
,
is_integerized
=
True
,
bucket_size
=
bucket_size
,
combiner
=
combiner
,
dtype
=
dtype
)
def
insert_transformed_feature
(
self
,
columns_to_tensors
):
def
insert_transformed_feature
(
self
,
columns_to_tensors
):
"""Handles sparse column to id conversion."""
"""Handles sparse column to id conversion."""
input_tensor
=
self
.
_get_input_sparse_tensor
(
columns_to_tensors
)
input_tensor
=
self
.
_get_input_sparse_tensor
(
columns_to_tensors
)
...
@@ -505,29 +499,13 @@ def sparse_column_with_integerized_feature(column_name,
...
@@ -505,29 +499,13 @@ def sparse_column_with_integerized_feature(column_name,
ValueError: dtype is not integer.
ValueError: dtype is not integer.
"""
"""
return
_SparseColumnIntegerized
(
return
_SparseColumnIntegerized
(
column_name
,
bucket_size
,
combiner
=
combiner
,
dtype
=
dtype
)
column_name
,
is_integerized
=
True
,
bucket_size
=
bucket_size
,
combiner
=
combiner
,
dtype
=
dtype
)
class
_SparseColumnHashed
(
_SparseColumn
):
class
_SparseColumnHashed
(
_SparseColumn
):
"""See `sparse_column_with_hash_bucket`."""
"""See `sparse_column_with_hash_bucket`."""
def
__new__
(
cls
,
column_name
,
hash_bucket_size
,
combiner
=
"sum"
,
dtype
=
dtypes
.
string
):
if
dtype
!=
dtypes
.
string
and
not
dtype
.
is_integer
:
raise
ValueError
(
"dtype must be string or integer. "
"dtype: {}, column_name: {}"
.
format
(
dtype
,
column_name
))
return
super
(
_SparseColumnHashed
,
cls
).
__new__
(
cls
,
column_name
,
bucket_size
=
hash_bucket_size
,
combiner
=
combiner
,
dtype
=
dtype
)
def
insert_transformed_feature
(
self
,
columns_to_tensors
):
def
insert_transformed_feature
(
self
,
columns_to_tensors
):
"""Handles sparse column to id conversion."""
"""Handles sparse column to id conversion."""
input_tensor
=
self
.
_get_input_sparse_tensor
(
columns_to_tensors
)
input_tensor
=
self
.
_get_input_sparse_tensor
(
columns_to_tensors
)
...
@@ -573,26 +551,16 @@ def sparse_column_with_hash_bucket(column_name,
...
@@ -573,26 +551,16 @@ def sparse_column_with_hash_bucket(column_name,
ValueError: hash_bucket_size is not greater than 2.
ValueError: hash_bucket_size is not greater than 2.
ValueError: dtype is neither string nor integer.
ValueError: dtype is neither string nor integer.
"""
"""
return
_SparseColumnHashed
(
column_name
,
hash_bucket_size
,
combiner
,
dtype
)
return
_SparseColumnHashed
(
column_name
,
bucket_size
=
hash_bucket_size
,
combiner
=
combiner
,
dtype
=
dtype
)
class
_SparseColumnKeys
(
_SparseColumn
):
class
_SparseColumnKeys
(
_SparseColumn
):
"""See `sparse_column_with_keys`."""
"""See `sparse_column_with_keys`."""
def
__new__
(
cls
,
column_name
,
keys
,
default_value
=-
1
,
combiner
=
"sum"
,
dtype
=
dtypes
.
string
):
if
(
not
dtype
.
is_integer
)
and
(
dtype
!=
dtypes
.
string
):
raise
TypeError
(
"Only integer and string are currently supported."
)
return
super
(
_SparseColumnKeys
,
cls
).
__new__
(
cls
,
column_name
,
combiner
=
combiner
,
lookup_config
=
_SparseIdLookupConfig
(
keys
=
keys
,
vocab_size
=
len
(
keys
),
default_value
=
default_value
),
dtype
=
dtype
)
def
insert_transformed_feature
(
self
,
columns_to_tensors
):
def
insert_transformed_feature
(
self
,
columns_to_tensors
):
"""Handles sparse column to id conversion."""
"""Handles sparse column to id conversion."""
input_tensor
=
self
.
_get_input_sparse_tensor
(
columns_to_tensors
)
input_tensor
=
self
.
_get_input_sparse_tensor
(
columns_to_tensors
)
...
@@ -614,7 +582,7 @@ def sparse_column_with_keys(
...
@@ -614,7 +582,7 @@ def sparse_column_with_keys(
Args:
Args:
column_name: A string defining sparse column name.
column_name: A string defining sparse column name.
keys: A list defining vocabulary. Must be castable to `dtype`.
keys: A list
or tuple
defining vocabulary. Must be castable to `dtype`.
default_value: The value to use for out-of-vocabulary feature values.
default_value: The value to use for out-of-vocabulary feature values.
Default is -1.
Default is -1.
combiner: A string specifying how to reduce if the sparse column is
combiner: A string specifying how to reduce if the sparse column is
...
@@ -630,38 +598,18 @@ def sparse_column_with_keys(
...
@@ -630,38 +598,18 @@ def sparse_column_with_keys(
Returns:
Returns:
A _SparseColumnKeys with keys configuration.
A _SparseColumnKeys with keys configuration.
"""
"""
keys
=
tuple
(
keys
)
return
_SparseColumnKeys
(
return
_SparseColumnKeys
(
column_name
,
tuple
(
keys
),
default_value
=
default_value
,
combiner
=
combiner
,
column_name
,
lookup_config
=
_SparseIdLookupConfig
(
keys
=
keys
,
vocab_size
=
len
(
keys
),
default_value
=
default_value
),
combiner
=
combiner
,
dtype
=
dtype
)
dtype
=
dtype
)
class
_SparseColumnVocabulary
(
_SparseColumn
):
class
_SparseColumnVocabulary
(
_SparseColumn
):
"""See `sparse_column_with_vocabulary_file`."""
"""See `sparse_column_with_vocabulary_file`."""
def
__new__
(
cls
,
column_name
,
vocabulary_file
,
num_oov_buckets
=
0
,
vocab_size
=
None
,
default_value
=-
1
,
combiner
=
"sum"
,
dtype
=
dtypes
.
string
):
if
dtype
!=
dtypes
.
string
and
not
dtype
.
is_integer
:
raise
ValueError
(
"dtype must be string or integer. "
"dtype: {}, column_name: {}"
.
format
(
dtype
,
column_name
))
return
super
(
_SparseColumnVocabulary
,
cls
).
__new__
(
cls
,
column_name
,
combiner
=
combiner
,
lookup_config
=
_SparseIdLookupConfig
(
vocabulary_file
=
vocabulary_file
,
num_oov_buckets
=
num_oov_buckets
,
vocab_size
=
vocab_size
,
default_value
=
default_value
),
dtype
=
dtype
)
def
insert_transformed_feature
(
self
,
columns_to_tensors
):
def
insert_transformed_feature
(
self
,
columns_to_tensors
):
"""Handles sparse column to id conversion."""
"""Handles sparse column to id conversion."""
st
=
self
.
_get_input_sparse_tensor
(
columns_to_tensors
)
st
=
self
.
_get_input_sparse_tensor
(
columns_to_tensors
)
...
@@ -726,10 +674,11 @@ def sparse_column_with_vocabulary_file(column_name,
...
@@ -726,10 +674,11 @@ def sparse_column_with_vocabulary_file(column_name,
return
_SparseColumnVocabulary
(
return
_SparseColumnVocabulary
(
column_name
,
column_name
,
vocabulary_file
,
lookup_config
=
_SparseIdLookupConfig
(
num_oov_buckets
=
num_oov_buckets
,
vocabulary_file
=
vocabulary_file
,
vocab_size
=
vocab_size
,
num_oov_buckets
=
num_oov_buckets
,
default_value
=
default_value
,
vocab_size
=
vocab_size
,
default_value
=
default_value
),
combiner
=
combiner
,
combiner
=
combiner
,
dtype
=
dtype
)
dtype
=
dtype
)
...
...
tensorflow/contrib/layers/python/layers/feature_column_test.py
浏览文件 @
29293fb6
...
@@ -554,6 +554,55 @@ class FeatureColumnTest(test.TestCase):
...
@@ -554,6 +554,55 @@ class FeatureColumnTest(test.TestCase):
sparse_result
=
sess
.
run
(
sparse_output
)
sparse_result
=
sess
.
run
(
sparse_output
)
self
.
assertEquals
(
expected_shape
,
list
(
sparse_result
.
dense_shape
))
self
.
assertEquals
(
expected_shape
,
list
(
sparse_result
.
dense_shape
))
def
testSparseColumnIntegerizedDeepCopy
(
self
):
"""Tests deepcopy of sparse_column_with_integerized_feature."""
column
=
fc
.
sparse_column_with_integerized_feature
(
"a"
,
10
)
self
.
assertEqual
(
"a"
,
column
.
name
)
column_copy
=
copy
.
deepcopy
(
column
)
self
.
assertEqual
(
"a"
,
column_copy
.
name
)
self
.
assertEqual
(
10
,
column_copy
.
bucket_size
)
self
.
assertTrue
(
column_copy
.
is_integerized
)
def
testSparseColumnHashBucketDeepCopy
(
self
):
"""Tests deepcopy of sparse_column_with_hash_bucket."""
column
=
fc
.
sparse_column_with_hash_bucket
(
"a"
,
10
)
self
.
assertEqual
(
"a"
,
column
.
name
)
column_copy
=
copy
.
deepcopy
(
column
)
self
.
assertEqual
(
"a"
,
column_copy
.
name
)
self
.
assertEqual
(
10
,
column_copy
.
bucket_size
)
self
.
assertFalse
(
column_copy
.
is_integerized
)
def
testSparseColumnKeysDeepCopy
(
self
):
"""Tests deepcopy of sparse_column_with_keys."""
column
=
fc
.
sparse_column_with_keys
(
"a"
,
keys
=
[
"key0"
,
"key1"
,
"key2"
])
self
.
assertEqual
(
"a"
,
column
.
name
)
column_copy
=
copy
.
deepcopy
(
column
)
self
.
assertEqual
(
"a"
,
column_copy
.
name
)
self
.
assertEqual
(
fc
.
_SparseIdLookupConfig
(
# pylint: disable=protected-access
keys
=
(
"key0"
,
"key1"
,
"key2"
),
vocab_size
=
3
,
default_value
=-
1
),
column_copy
.
lookup_config
)
self
.
assertFalse
(
column_copy
.
is_integerized
)
def
testSparseColumnVocabularyDeepCopy
(
self
):
"""Tests deepcopy of sparse_column_with_vocabulary_file."""
column
=
fc
.
sparse_column_with_vocabulary_file
(
"a"
,
vocabulary_file
=
"path_to_file"
,
vocab_size
=
3
)
self
.
assertEqual
(
"a"
,
column
.
name
)
column_copy
=
copy
.
deepcopy
(
column
)
self
.
assertEqual
(
"a"
,
column_copy
.
name
)
self
.
assertEqual
(
fc
.
_SparseIdLookupConfig
(
# pylint: disable=protected-access
vocabulary_file
=
"path_to_file"
,
num_oov_buckets
=
0
,
vocab_size
=
3
,
default_value
=-
1
),
column_copy
.
lookup_config
)
self
.
assertFalse
(
column_copy
.
is_integerized
)
def
testCreateFeatureSpec
(
self
):
def
testCreateFeatureSpec
(
self
):
sparse_col
=
fc
.
sparse_column_with_hash_bucket
(
sparse_col
=
fc
.
sparse_column_with_hash_bucket
(
"sparse_column"
,
hash_bucket_size
=
100
)
"sparse_column"
,
hash_bucket_size
=
100
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录