Commit dee19f6e authored by Qiao Longfei

add preprocess.py

Parent b3943336
@@ -15,11 +15,7 @@
```
## Runtime Environment
-You need to install PaddlePaddle Fluid first, then run:
-```shell
-pip install -r requirements.txt
-```
+You need to install PaddlePaddle Fluid first.
## Dataset
This article uses the Criteo dataset from the [Display Advertising Challenge](https://www.kaggle.com/c/criteo-display-ad-challenge/) hosted by Kaggle.
...
@@ -20,11 +20,7 @@ factorization machines, please refer to the paper [factorization
machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf)
## Environment
-You should install PaddlePaddle Fluid first, and run:
-```shell
-pip install -r requirements.txt
-```
+You should install PaddlePaddle Fluid first.
## Dataset
This example uses the Criteo dataset which was used for the [Display Advertising
...
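As a quick sanity check that the environment is ready (a minimal sketch, not part of this commit), PaddlePaddle Fluid can be imported directly:

```python
# Minimal environment check (sketch): confirm PaddlePaddle Fluid is importable.
import paddle
import paddle.fluid as fluid

print(paddle.__version__)  # the installed PaddlePaddle version string
```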
# -*- coding: utf-8 -*-
import re
import argparse


def parse_args():
    parser = argparse.ArgumentParser(
        description="Paddle Fluid word2vec preprocess")
    parser.add_argument(
        '--data_path',
        type=str,
        required=True,
        help="The path of the training dataset")
    parser.add_argument(
        '--dict_path',
        type=str,
        default='./dict',
        help="The path of the generated dict")
    parser.add_argument(
        '--freq',
        type=int,
        default=5,
        help="If the word count is not greater than freq, the word will be removed from the dict")
    return parser.parse_args()


def preprocess(data_path, dict_path, freq):
    """
    Preprocess the data: build a word-count dictionary and save it to dict_path.

    :param data_path: the input data path.
    :param dict_path: the generated dict path; each line of the dict is "word count".
    :param freq: words whose count is not greater than freq are removed.
    :return: None
    """
    # count the occurrences of every word
    word_count = dict()
    with open(data_path) as f:
        for line in f:
            line = line.lower()
            line = re.sub("[^0-9a-z ]", "", line)
            words = line.split()
            for item in words:
                if item in word_count:
                    word_count[item] = word_count[item] + 1
                else:
                    word_count[item] = 1

    # drop low-frequency words
    item_to_remove = []
    for item in word_count:
        if word_count[item] <= freq:
            item_to_remove.append(item)
    for item in item_to_remove:
        del word_count[item]

    # write the dict as "word count" lines
    with open(dict_path, 'w+') as f:
        for k, v in word_count.items():
            f.write(str(k) + " " + str(v) + '\n')


if __name__ == "__main__":
    args = parse_args()
    preprocess(args.data_path, args.dict_path, args.freq)
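A small driver like the one below (a sketch; the corpus and file names are made up, and it assumes the script above is importable as `preprocess`) shows the same behaviour as running `python preprocess.py --data_path ... --dict_path ... --freq 5`:

```python
# Illustrative usage (sketch): build a tiny corpus, run the preprocessing
# routine, and print the resulting "word count" dictionary file.
from preprocess import preprocess

with open("tiny_corpus.txt", "w") as f:
    # every word below occurs at least 10 times, so nothing is dropped at freq=5
    f.write("the quick brown fox jumps over the lazy dog " * 10)

preprocess(data_path="tiny_corpus.txt", dict_path="tiny_dict.txt", freq=5)

with open("tiny_dict.txt") as f:
    print(f.read())  # lines such as "the 20" and "fox 10"
```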
@@ -6,13 +6,16 @@ import random
from collections import Counter
"""
refs: https://github.com/NELSONZHAO/zhihu/blob/master/skip_gram/Skip-Gram-English-Corpus.ipynb
+text8 dataset
+http://mattmahoney.net/dc/textdata.html
"""
with open('data/text8.txt') as f:
    text = f.read()
+# Define a function to preprocess the data
def preprocess(text, freq=5):
    '''
    Preprocess the text
@@ -52,7 +55,6 @@ print(words[:20])
# Build the vocabulary mapping table
vocab = set(words)
vocab_to_int = {w: c for c, w in enumerate(vocab)}
-int_to_vocab = {c: w for c, w in enumerate(vocab)}
dict_size = len(set(words))
...
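To make the mapping above concrete, here is a toy illustration (a sketch with a made-up word list) of what `vocab_to_int` looks like and how the corpus is encoded as integer ids:

```python
# Toy illustration (sketch): every distinct word gets a dense integer id,
# and the corpus is then re-expressed as a list of those ids.
words = ["the", "quick", "brown", "fox", "jumps", "over", "the", "lazy", "dog"]

vocab = set(words)
vocab_to_int = {w: c for c, w in enumerate(vocab)}

int_words = [vocab_to_int[w] for w in words]
print(vocab_to_int)  # e.g. {'fox': 0, 'dog': 1, ...}; set order is arbitrary
print(int_words)     # both occurrences of "the" map to the same id
```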