Commit f145df7b authored by yinhaofeng

fm add data

Parent 3b0a9cf5
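This commit adds the Criteo data-preparation pipeline for the FM model: download_preprocess.py downloads and unpacks the raw dataset together with a prebuilt feature dictionary, preprocess.py splits the raw files into train/test parts and builds the feature dictionary, get_slot_data.py converts each record to slot format, and a short shell snippet chains the steps together.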
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import sys
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
TOOLS_PATH = os.path.join(LOCAL_PATH, "..", "..", "tools")
sys.path.append(TOOLS_PATH)
from paddlerec.tools.tools import download_file_and_uncompress, download_file
if __name__ == '__main__':
    url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz"
    url2 = "https://paddlerec.bj.bcebos.com/deepfm%2Ffeat_dict_10.pkl2"

    print("download and extract starting...")
    download_file_and_uncompress(url)
    download_file(url2, "./deepfm%2Ffeat_dict_10.pkl2", True)
    print("download and extract finished")

    print("preprocessing...")
    os.system("python preprocess.py")
    print("preprocess done")

    shutil.rmtree("raw_data")
    print("done")
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddlerec.core.utils import envs
import paddle.fluid.incubate.data_generator as dg
try:
    import cPickle as pickle
except ImportError:
    import pickle
class Reader(dg.MultiSlotDataGenerator):
    def __init__(self, config):
        dg.MultiSlotDataGenerator.__init__(self)
        _config = envs.load_yaml(config)

    def init(self):
        self.cont_min_ = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        self.cont_max_ = [
            5775, 257675, 65535, 969, 23159456, 431037, 56311, 6047, 29019, 46,
            231, 4008, 7393
        ]
        self.cont_diff_ = [
            self.cont_max_[i] - self.cont_min_[i]
            for i in range(len(self.cont_min_))
        ]
        self.continuous_range_ = range(1, 14)
        self.categorical_range_ = range(14, 40)
        # load preprocessed feature dict
        self.feat_dict_name = "sample_data/feat_dict_10.pkl2"
        self.feat_dict_ = pickle.load(open(self.feat_dict_name, 'rb'))

    def _process_line(self, line):
        features = line.rstrip('\n').split('\t')
        feat_idx = []
        feat_value = []
        # continuous features: id from feat_dict, min-max normalized value;
        # missing fields map to index 0
        for idx in self.continuous_range_:
            if features[idx] == '':
                feat_idx.append(0)
                feat_value.append(0.0)
            else:
                feat_idx.append(self.feat_dict_[idx])
                feat_value.append(
                    (float(features[idx]) - self.cont_min_[idx - 1]) /
                    self.cont_diff_[idx - 1])
        # categorical features: id from feat_dict, value fixed to 1.0;
        # missing or unseen values map to index 0
        for idx in self.categorical_range_:
            if features[idx] == '' or features[idx] not in self.feat_dict_:
                feat_idx.append(0)
                feat_value.append(0.0)
            else:
                feat_idx.append(self.feat_dict_[features[idx]])
                feat_value.append(1.0)
        label = [int(features[0])]
        return feat_idx, feat_value, label

    def generate_sample(self, line):
        """
        Read the data line by line, convert each record to slot format
        and print it to stdout.
        """

        def data_iter():
            feat_idx, feat_value, label = self._process_line(line)
            s = ""
            for i in [('feat_idx', feat_idx), ('feat_value', feat_value),
                      ('label', label)]:
                k = i[0]
                v = i[1]
                for j in v:
                    s += " " + k + ":" + str(j)
            print(s.strip())
            yield None

        return data_iter


reader = Reader(
    "../config.yaml")  # run this file in original folder to find config.yaml
reader.init()
reader.run_from_stdin()
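For reference, get_slot_data.py is meant to be fed tab-separated Criteo records on stdin (the shell snippet at the bottom of this commit does exactly that) and emits one slot-formatted line per record: all feat_idx entries, then all feat_value entries, then the label. The following is a minimal, self-contained sketch of that conversion using a toy feature dictionary in place of sample_data/feat_dict_10.pkl2; the ids and the demo record are illustrative only.

# A self-contained sketch (toy data only) of the conversion performed by
# Reader above.  toy_feat_dict stands in for sample_data/feat_dict_10.pkl2;
# the ids and the demo record below are illustrative, not real Criteo values.
toy_feat_dict = {i: i for i in range(1, 14)}            # continuous field positions
toy_feat_dict.update({'68fd1e64': 14, '05db9164': 15})  # frequent categorical values

cont_min = [0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
cont_diff = [5775, 257678, 65535, 969, 23159456, 431037, 56311, 6047, 29019,
             46, 231, 4008, 7393]


def to_slot_line(record):
    fields = record.rstrip('\n').split('\t')
    feat_idx, feat_value = [], []
    for idx in range(1, 14):        # continuous: min-max normalized value
        if fields[idx] == '':
            feat_idx.append(0)
            feat_value.append(0.0)
        else:
            feat_idx.append(toy_feat_dict[idx])
            feat_value.append(
                (float(fields[idx]) - cont_min[idx - 1]) / cont_diff[idx - 1])
    for idx in range(14, 40):       # categorical: value 1.0 if the id is known
        if fields[idx] in toy_feat_dict:
            feat_idx.append(toy_feat_dict[fields[idx]])
            feat_value.append(1.0)
        else:
            feat_idx.append(0)
            feat_value.append(0.0)
    slots = ['feat_idx:%d' % v for v in feat_idx]
    slots += ['feat_value:%s' % v for v in feat_value]
    slots.append('label:' + fields[0])
    return ' '.join(slots)


# 40 tab-separated fields: label, 13 continuous, 26 categorical (mostly empty here)
demo = '\t'.join(['0'] + ['1'] * 13 + ['68fd1e64'] + [''] * 25)
print(to_slot_line(demo))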
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy
from collections import Counter
import shutil
import pickle


def get_raw_data():
    if not os.path.isdir('raw_data'):
        os.mkdir('raw_data')

    fin = open('train.txt', 'r')
    fout = open('raw_data/part-0', 'w')
    for line_idx, line in enumerate(fin):
        if line_idx % 200000 == 0 and line_idx != 0:
            fout.close()
            cur_part_idx = int(line_idx / 200000)
            fout = open('raw_data/part-' + str(cur_part_idx), 'w')
        fout.write(line)
    fout.close()
    fin.close()


def split_data():
    split_rate_ = 0.9
    dir_train_file_idx_ = 'aid_data/train_file_idx.txt'
    filelist_ = [
        'raw_data/part-%d' % x for x in range(len(os.listdir('raw_data')))
    ]

    if not os.path.exists(dir_train_file_idx_):
        train_file_idx = list(
            numpy.random.choice(
                len(filelist_), int(len(filelist_) * split_rate_), False))
        with open(dir_train_file_idx_, 'w') as fout:
            fout.write(str(train_file_idx))
    else:
        with open(dir_train_file_idx_, 'r') as fin:
            train_file_idx = eval(fin.read())

    for idx in range(len(filelist_)):
        if idx in train_file_idx:
            shutil.move(filelist_[idx], 'train_data')
        else:
            shutil.move(filelist_[idx], 'test_data')


def get_feat_dict():
    freq_ = 10
    dir_feat_dict_ = 'aid_data/feat_dict_' + str(freq_) + '.pkl2'
    continuous_range_ = range(1, 14)
    categorical_range_ = range(14, 40)

    if not os.path.exists(dir_feat_dict_):
        # Count the number of occurrences of discrete features
        feat_cnt = Counter()
        with open('train.txt', 'r') as fin:
            for line_idx, line in enumerate(fin):
                if line_idx % 100000 == 0:
                    print('generating feature dict', line_idx / 45000000)
                features = line.rstrip('\n').split('\t')
                for idx in categorical_range_:
                    if features[idx] == '':
                        continue
                    feat_cnt.update([features[idx]])

        # Only retain discrete features with high frequency
        dis_feat_set = set()
        for feat, ot in feat_cnt.items():
            if ot >= freq_:
                dis_feat_set.add(feat)

        # Create a dictionary for continuous and discrete features
        feat_dict = {}
        tc = 1
        # Continuous features
        for idx in continuous_range_:
            feat_dict[idx] = tc
            tc += 1
        for feat in dis_feat_set:
            feat_dict[feat] = tc
            tc += 1

        # Save dictionary
        with open(dir_feat_dict_, 'wb') as fout:
            pickle.dump(feat_dict, fout, protocol=2)

        print('args.num_feat ', len(feat_dict) + 1)


if __name__ == '__main__':
    if not os.path.isdir('train_data'):
        os.mkdir('train_data')
    if not os.path.isdir('test_data'):
        os.mkdir('test_data')
    if not os.path.isdir('aid_data'):
        os.mkdir('aid_data')

    get_raw_data()
    split_data()
    get_feat_dict()
    print('Done!')
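For a quick sanity check after preprocess.py has run (assuming aid_data/feat_dict_10.pkl2 exists), the saved dictionary can be loaded and inspected as below; id 0 is implicitly reserved for missing or infrequent features, which is why the feature count reported above is len(feat_dict) + 1.

import pickle

with open('aid_data/feat_dict_10.pkl2', 'rb') as f:
    feat_dict = pickle.load(f)

# Continuous field positions 1..13 receive the first ids; frequent categorical
# strings receive the remaining ones.
print('num_feat =', len(feat_dict) + 1)
print('continuous ids:', [feat_dict[i] for i in range(1, 14)])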
python download_preprocess.py
mv ./deepfm%2Ffeat_dict_10.pkl2 sample_data/feat_dict_10.pkl2

mkdir slot_train_data
for i in `ls ./train_data`
do
    cat train_data/$i | python get_slot_data.py > slot_train_data/$i
done

mkdir slot_test_data
for i in `ls ./test_data`
do
    cat test_data/$i | python get_slot_data.py > slot_test_data/$i
done