magicwindyyd / mindspore
Forked from MindSpore / mindspore (in sync with the fork source)
Commit 3f2f6faa
Authored May 29, 2020 by wukesong

wide&deep data process

Parent: c8d910b9
Showing 1 changed file with 268 additions and 0 deletions

model_zoo/wide_and_deep/src/process_data.py  +268  -0  (new file, mode 0 → 100644)
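For orientation, a hedged sketch of how this script might be run once the raw Criteo text file (named train_small.txt, as expected by the main block below) is placed under the raw data directory; the flag values simply repeat the argparse defaults in the file, and the invocation itself is an assumption, not part of the commit:

python model_zoo/wide_and_deep/src/process_data.py --raw_data_path /opt/npu/data/origin_criteo_data/ --output_path /opt/npu/data/origin_criteo_data/h5_data/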
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Criteo data process
"""
import os
import pickle
import collections
import argparse
import numpy as np
import pandas as pd

TRAIN_LINE_COUNT = 45840617
TEST_LINE_COUNT = 6042135
class CriteoStatsDict():
    """create data dict"""
    def __init__(self):
        self.field_size = 39  # value_1-13; cat_1-26;
        self.val_cols = ["val_{}".format(i + 1) for i in range(13)]
        self.cat_cols = ["cat_{}".format(i + 1) for i in range(26)]
        #
        self.val_min_dict = {col: 0 for col in self.val_cols}
        self.val_max_dict = {col: 0 for col in self.val_cols}
        self.cat_count_dict = {col: collections.defaultdict(int) for col in self.cat_cols}
        #
        self.oov_prefix = "OOV_"
        self.cat2id_dict = {}
        self.cat2id_dict.update({col: i for i, col in enumerate(self.val_cols)})
        self.cat2id_dict.update({self.oov_prefix + col: i + len(self.val_cols)
                                 for i, col in enumerate(self.cat_cols)})
    #
    def stats_vals(self, val_list):
        """vals status"""
        assert len(val_list) == len(self.val_cols)

        def map_max_min(i, val):
            key = self.val_cols[i]
            if val != "":
                if float(val) > self.val_max_dict[key]:
                    self.val_max_dict[key] = float(val)
                if float(val) < self.val_min_dict[key]:
                    self.val_min_dict[key] = float(val)
        #
        for i, val in enumerate(val_list):
            map_max_min(i, val)
    #
    def stats_cats(self, cat_list):
        assert len(cat_list) == len(self.cat_cols)

        def map_cat_count(i, cat):
            key = self.cat_cols[i]
            self.cat_count_dict[key][cat] += 1
        #
        for i, cat in enumerate(cat_list):
            map_cat_count(i, cat)
    #
    def save_dict(self, output_path, prefix=""):
        with open(os.path.join(output_path, "{}val_max_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.val_max_dict, file_wrt)
        with open(os.path.join(output_path, "{}val_min_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.val_min_dict, file_wrt)
        with open(os.path.join(output_path, "{}cat_count_dict.pkl".format(prefix)), "wb") as file_wrt:
            pickle.dump(self.cat_count_dict, file_wrt)
    #
    def load_dict(self, dict_path, prefix=""):
        with open(os.path.join(dict_path, "{}val_max_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.val_max_dict = pickle.load(file_wrt)
        with open(os.path.join(dict_path, "{}val_min_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.val_min_dict = pickle.load(file_wrt)
        with open(os.path.join(dict_path, "{}cat_count_dict.pkl".format(prefix)), "rb") as file_wrt:
            self.cat_count_dict = pickle.load(file_wrt)
        print("val_max_dict.items()[:50]: {}".format(list(self.val_max_dict.items())))
        print("val_min_dict.items()[:50]: {}".format(list(self.val_min_dict.items())))
    #
    #
    def get_cat2id(self, threshold=100):
        """get cat to id"""
        # before_all_count = 0
        # after_all_count = 0
        for key, cat_count_d in self.cat_count_dict.items():
            new_cat_count_d = dict(filter(lambda x: x[1] > threshold, cat_count_d.items()))
            for cat_str, _ in new_cat_count_d.items():
                self.cat2id_dict[key + "_" + cat_str] = len(self.cat2id_dict)
        # print("before_all_count: {}".format(before_all_count)) # before_all_count: 33762577
        # print("after_all_count: {}".format(after_all_count)) # after_all_count: 184926
        print("cat2id_dict.size: {}".format(len(self.cat2id_dict)))
        # dict.items() is not subscriptable in Python 3, so convert to a list before slicing
        print("cat2id_dict.items()[:50]: {}".format(list(self.cat2id_dict.items())[:50]))
    #
    def map_cat2id(self, values, cats):
        """map cat to id"""
        def minmax_scale_value(i, val):
            # min_v = float(self.val_min_dict["val_{}".format(i+1)])
            max_v = float(self.val_max_dict["val_{}".format(i + 1)])
            # return (float(val) - min_v) * 1.0 / (max_v - min_v)
            return float(val) * 1.0 / max_v

        id_list = []
        weight_list = []
        for i, val in enumerate(values):
            if val == "":
                id_list.append(i)
                weight_list.append(0)
            else:
                key = "val_{}".format(i + 1)
                id_list.append(self.cat2id_dict[key])
                weight_list.append(minmax_scale_value(i, float(val)))
        #
        for i, cat_str in enumerate(cats):
            key = "cat_{}".format(i + 1) + "_" + cat_str
            if key in self.cat2id_dict:
                id_list.append(self.cat2id_dict[key])
            else:
                id_list.append(self.cat2id_dict[self.oov_prefix + "cat_{}".format(i + 1)])
            weight_list.append(1.0)
        return id_list, weight_list
#
def mkdir_path(file_path):
    if not os.path.exists(file_path):
        os.makedirs(file_path)
#
def statsdata(data_file_path, output_path, criteo_stats):
    """data status"""
    with open(data_file_path, encoding="utf-8") as file_in:
        errorline_list = []
        count = 0
        for line in file_in:
            count += 1
            line = line.strip("\n")
            # the raw Criteo file is tab-separated: label + 13 values + 26 categories = 40 fields
            items = line.split("\t")
            if len(items) != 40:
                errorline_list.append(count)
                print("line: {}".format(line))
                continue
            if count % 1000000 == 0:
                print("Have handled {}w lines.".format(count // 10000))
            # if count % 5000000 == 0:
            #     print("Have handled {}w lines.".format(count // 10000))
            # label = items[0]
            values = items[1:14]
            cats = items[14:]
            assert len(values) == 13, "value.size: {}".format(len(values))
            assert len(cats) == 26, "cat.size: {}".format(len(cats))
            criteo_stats.stats_vals(values)
            criteo_stats.stats_cats(cats)
    criteo_stats.save_dict(output_path)
#
def add_write(file_path, wr_str):
    with open(file_path, "a", encoding="utf-8") as file_out:
        file_out.write(wr_str + "\n")
#
def random_split_trans2h5(in_file_path, output_path, criteo_stats, part_rows=2000000, test_size=0.1, seed=2020):
    """random split trans2h5"""
    test_size = int(TRAIN_LINE_COUNT * test_size)
    # train_size = TRAIN_LINE_COUNT - test_size
    all_indices = [i for i in range(TRAIN_LINE_COUNT)]
    np.random.seed(seed)
    np.random.shuffle(all_indices)
    print("all_indices.size: {}".format(len(all_indices)))
    # lines_count_dict = collections.defaultdict(int)
    test_indices_set = set(all_indices[:test_size])
    print("test_indices_set.size: {}".format(len(test_indices_set)))
    print("------" * 10 + "\n" * 2)

    train_feature_file_name = os.path.join(output_path, "train_input_part_{}.h5")
    train_label_file_name = os.path.join(output_path, "train_output_part_{}.h5")
    test_feature_file_name = os.path.join(output_path, "test_input_part_{}.h5")
    # test labels go to their own *_output_* files, separate from the test features
    test_label_file_name = os.path.join(output_path, "test_output_part_{}.h5")

    train_feature_list = []
    train_label_list = []
    test_feature_list = []
    test_label_list = []
    with open(in_file_path, encoding="utf-8") as file_in:
        count = 0
        train_part_number = 0
        test_part_number = 0
        for i, line in enumerate(file_in):
            count += 1
            if count % 1000000 == 0:
                print("Have handled {}w lines.".format(count // 10000))
            line = line.strip("\n")
            items = line.split("\t")
            if len(items) != 40:
                continue
            label = float(items[0])
            values = items[1:14]
            cats = items[14:]
            assert len(values) == 13, "value.size: {}".format(len(values))
            assert len(cats) == 26, "cat.size: {}".format(len(cats))
            ids, wts = criteo_stats.map_cat2id(values, cats)
            if i not in test_indices_set:
                train_feature_list.append(ids + wts)
                train_label_list.append(label)
            else:
                test_feature_list.append(ids + wts)
                test_label_list.append(label)
            if train_label_list and (len(train_label_list) % part_rows == 0):
                pd.DataFrame(np.asarray(train_feature_list)).to_hdf(train_feature_file_name.format(train_part_number), key="fixed")
                pd.DataFrame(np.asarray(train_label_list)).to_hdf(train_label_file_name.format(train_part_number), key="fixed")
                train_feature_list = []
                train_label_list = []
                train_part_number += 1
            if test_label_list and (len(test_label_list) % part_rows == 0):
                pd.DataFrame(np.asarray(test_feature_list)).to_hdf(test_feature_file_name.format(test_part_number), key="fixed")
                pd.DataFrame(np.asarray(test_label_list)).to_hdf(test_label_file_name.format(test_part_number), key="fixed")
                test_feature_list = []
                test_label_list = []
                test_part_number += 1
    #
    if train_label_list:
        pd.DataFrame(np.asarray(train_feature_list)).to_hdf(train_feature_file_name.format(train_part_number), key="fixed")
        pd.DataFrame(np.asarray(train_label_list)).to_hdf(train_label_file_name.format(train_part_number), key="fixed")
    if test_label_list:
        pd.DataFrame(np.asarray(test_feature_list)).to_hdf(test_feature_file_name.format(test_part_number), key="fixed")
        pd.DataFrame(np.asarray(test_label_list)).to_hdf(test_label_file_name.format(test_part_number), key="fixed")
#
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Get and Process datasets")
    parser.add_argument("--raw_data_path", default="/opt/npu/data/origin_criteo_data/",
                        help="The path to save dataset")
    parser.add_argument("--output_path", default="/opt/npu/data/origin_criteo_data/h5_data/",
                        help="The path to save dataset")
    args, _ = parser.parse_known_args()
    base_path = args.raw_data_path
    criteo_stat = CriteoStatsDict()
    # step 1, stats the vocab and normalize value
    datafile_path = base_path + "train_small.txt"
    stats_out_path = base_path + "stats_dict/"
    mkdir_path(stats_out_path)
    statsdata(datafile_path, stats_out_path, criteo_stat)
    print("------" * 10)
    criteo_stat.load_dict(dict_path=stats_out_path, prefix="")
    criteo_stat.get_cat2id(threshold=100)
    # step 2, transform data trans2h5; version 2: np.random.shuffle
    infile_path = base_path + "train_small.txt"
    mkdir_path(args.output_path)
    random_split_trans2h5(infile_path, args.output_path, criteo_stat, part_rows=2000000, test_size=0.1, seed=2020)
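After the script finishes, each H5 part can be loaded back with pandas. A minimal read-back sketch, assuming pandas with the PyTables HDF5 backend is installed and at least one training part was written; the part index 0 and the paths below are illustrative only, mirroring the output_path default above:

import pandas as pd

features = pd.read_hdf("/opt/npu/data/origin_criteo_data/h5_data/train_input_part_0.h5", key="fixed")
labels = pd.read_hdf("/opt/npu/data/origin_criteo_data/h5_data/train_output_part_0.h5", key="fixed")
print(features.shape)  # (rows, 78): 39 feature ids followed by 39 weights per sample
print(labels.shape)    # (rows, 1): click labels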