Commit 060d7e47 authored by malin10

Merge branch 'master' of https://github.com/PaddlePaddle/PaddleRec into model_fix

......@@ -19,10 +19,16 @@ import copy
import os
import subprocess
import warnings
import sys
import logging
from paddlerec.core.engine.engine import Engine
from paddlerec.core.factory import TrainerFactory
from paddlerec.core.utils import envs
import paddlerec.core.engine.cluster_utils as cluster_utils
logger = logging.getLogger("root")
logger.propagate = False
class ClusterEngine(Engine):
......@@ -47,8 +53,38 @@ class ClusterEngine(Engine):
self.backend))
def start_worker_procs(self):
if (envs.get_runtime_environ("fleet_mode") == "COLLECTIVE"):
#trainer_ports = os.getenv("TRAINER_PORTS", None).split(",")
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
if cuda_visible_devices is None or cuda_visible_devices == "":
    selected_gpus = list(
        range(int(os.getenv("TRAINER_GPU_CARD_COUNT"))))
else:
    # change selected_gpus into relative values
    # e.g. CUDA_VISIBLE_DEVICES=4,5,6,7; args.selected_gpus=4,5,6,7;
    # therefore selected_gpus=0,1,2,3
    cuda_visible_devices_list = cuda_visible_devices.split(',')
    for x in range(int(os.getenv("TRAINER_GPU_CARD_COUNT"))):
        assert str(x) in cuda_visible_devices_list, "Can't find "\
            "your selected_gpus %s in CUDA_VISIBLE_DEVICES[%s]."\
            % (x, cuda_visible_devices)
    selected_gpus = [
        cuda_visible_devices_list.index(str(x))
        for x in range(int(os.getenv("TRAINER_GPU_CARD_COUNT")))
    ]
print("selected_gpus:{}".format(selected_gpus))
factory = "paddlerec.core.factory"
cmd = [sys.executable, "-u", "-m", factory, self.trainer]
logs_dir = envs.get_runtime_environ("log_dir")
print("use_paddlecloud_flag:{}".format(
cluster_utils.use_paddlecloud()))
if cluster_utils.use_paddlecloud():
cluster, pod = cluster_utils.get_cloud_cluster(selected_gpus)
logger.info("get cluster from cloud:{}".format(cluster))
procs = cluster_utils.start_local_trainers(
cluster, pod, cmd, log_dir=logs_dir)
print("cluster:{}".format(cluster))
print("pod:{}".format(pod))
else:
trainer = TrainerFactory.create(self.trainer)
trainer.run()
def start_master_procs(self):
if self.backend == "PADDLECLOUD":
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import logging
import socket
import time
import os
import signal
import copy
import sys
import subprocess
from contextlib import closing
import socket
logger = logging.getLogger("root")
logger.propagate = False
class Cluster(object):
def __init__(self, hdfs):
self.job_server = None
self.pods = []
self.hdfs = hdfs
self.job_stage_flag = None
def __str__(self):
return "job_server:{} pods:{} job_stage_flag:{} hdfs:{}".format(
self.job_server, [str(pod) for pod in self.pods],
self.job_stage_flag, self.hdfs)
def __eq__(self, cluster):
if len(self.pods) != len(cluster.pods):
return False
for a, b in zip(self.pods, cluster.pods):
if a != b:
return False
if self.job_stage_flag != cluster.job_stage_flag:
return False
return True
def __ne__(self, cluster):
return not self.__eq__(cluster)
def update_pods(self, cluster):
self.pods = copy.copy(cluster.pods)
def trainers_nranks(self):
return len(self.trainers_endpoints())
def pods_nranks(self):
return len(self.pods)
def trainers_endpoints(self):
r = []
for pod in self.pods:
for t in pod.trainers:
r.append(t.endpoint)
return r
def pods_endpoints(self):
r = []
for pod in self.pods:
ep = "{}:{}".format(pod.addr, pod.port)
assert pod.port is not None and pod.addr is not None, "{} not a valid endpoint".format(
    ep)
r.append(ep)
return r
def get_pod_by_id(self, pod_id):
for pod in self.pods:
if str(pod_id) == str(pod.id):
return pod
return None
class JobServer(object):
def __init__(self):
self.endpoint = None
def __str__(self):
return "{}".format(self.endpoint)
def __eq__(self, j):
return self.endpoint == j.endpoint
def __ne__(self, j):
return not self == j
class Trainer(object):
def __init__(self):
self.gpus = []
self.endpoint = None
self.rank = None
def __str__(self):
return "gpu:{} endpoint:{} rank:{}".format(self.gpus, self.endpoint,
self.rank)
def __eq__(self, t):
if len(self.gpus) != len(t.gpus):
return False
if self.endpoint != t.endpoint or \
self.rank != t.rank:
return False
for a, b in zip(self.gpus, t.gpus):
if a != b:
return False
return True
def __ne__(self, t):
return not self == t
def rank(self):
return self.rank
class Pod(object):
def __init__(self):
self.rank = None
self.id = None
self.addr = None
self.port = None
self.trainers = []
self.gpus = []
def __str__(self):
return "rank:{} id:{} addr:{} port:{} visible_gpu:{} trainers:{}".format(
self.rank, self.id, self.addr, self.port, self.gpus,
[str(t) for t in self.trainers])
def __eq__(self, pod):
if self.rank != pod.rank or \
self.id != pod.id or \
self.addr != pod.addr or \
self.port != pod.port:
logger.debug("pod {} != pod".format(self, pod))
return False
if len(self.trainers) != len(pod.trainers):
logger.debug("trainers {} != {}".format(self.trainers,
pod.trainers))
return False
for i in range(len(self.trainers)):
if self.trainers[i] != pod.trainers[i]:
logger.debug("trainer {} != {}".format(self.trainers[i],
pod.trainers[i]))
return False
return True
def __ne__(self, pod):
return not self == pod
def parse_response(self, res_pods):
pass
def rank(self):
return self.rank
def get_visible_gpus(self):
r = ""
for g in self.gpus:
r += "{},".format(g)
assert r != "", "this pod {} can't see any gpus".format(self)
r = r[:-1]
return r
def get_cluster(node_ips, node_ip, paddle_ports, selected_gpus):
assert type(paddle_ports) is list, "paddle_ports must be list"
cluster = Cluster(hdfs=None)
trainer_rank = 0
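# build one Trainer per selected GPU on every node; trainer ranks increase consecutively across nodes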
for node_rank, ip in enumerate(node_ips):
pod = Pod()
pod.rank = node_rank
pod.addr = ip
for i in range(len(selected_gpus)):
trainer = Trainer()
trainer.gpus.append(selected_gpus[i])
trainer.endpoint = "%s:%d" % (ip, paddle_ports[i])
trainer.rank = trainer_rank
trainer_rank += 1
pod.trainers.append(trainer)
cluster.pods.append(pod)
pod_rank = node_ips.index(node_ip)
return cluster, cluster.pods[pod_rank]
def get_cloud_cluster(selected_gpus, args_port=None):
#you can automatically get ip info while using paddlecloud multi nodes mode.
node_ips = os.getenv("PADDLE_TRAINERS")
assert node_ips is not None, "PADDLE_TRAINERS should not be None"
print("node_ips:{}".format(node_ips))
node_ip = os.getenv("POD_IP")
assert node_ip is not None, "POD_IP should not be None"
print("node_ip:{}".format(node_ip))
node_rank = os.getenv("PADDLE_TRAINER_ID")
assert node_rank is not None, "PADDLE_TRAINER_ID should not be None"
print("node_rank:{}".format(node_rank))
node_ips = node_ips.split(",")
num_nodes = len(node_ips)
node_rank = int(node_rank)
started_port = args_port
print("num_nodes:", num_nodes)
if num_nodes > 1:
try:
paddle_port = int(os.getenv("PADDLE_PORT", ""))
paddle_port_num = int(os.getenv("TRAINER_PORTS_NUM", ""))
if paddle_port_num >= len(
selected_gpus) and paddle_port != args_port:
logger.warning("Use Cloud specified port:{}.".format(
paddle_port))
started_port = paddle_port
except Exception as e:
print(e)
pass
if started_port is None:
started_port = 6170
logger.debug("parsed from args:node_ips:{} \
node_ip:{} node_rank:{} started_port:{}"
.format(node_ips, node_ip, node_rank, started_port))
ports = [x for x in range(started_port, started_port + len(selected_gpus))]
cluster, pod = get_cluster(node_ips, node_ip, ports, selected_gpus)
return cluster, cluster.pods[node_rank]
def use_paddlecloud():
node_ips = os.getenv("PADDLE_TRAINERS", None)
node_ip = os.getenv("POD_IP", None)
node_rank = os.getenv("PADDLE_TRAINER_ID", None)
if node_ips is None or node_ip is None or node_rank is None:
return False
else:
return True
class TrainerProc(object):
def __init__(self):
self.proc = None
self.log_fn = None
self.log_offset = None
self.rank = None
self.local_rank = None
self.cmd = None
def start_local_trainers(cluster, pod, cmd, log_dir=None):
current_env = os.environ.copy()
# Paddle broadcasts ncclUniqueId over sockets, and proxies may make
# trainers unreachable, so delete the proxy variables. Setting them to ""
# would make grpc log a "bad uri" error, so they are removed entirely.
current_env.pop("http_proxy", None)
current_env.pop("https_proxy", None)
procs = []
for idx, t in enumerate(pod.trainers):
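# per-trainer distributed environment: selected GPUs, global rank, own endpoint, and the full endpoint list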
proc_env = {
"FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in t.gpus]),
"PADDLE_TRAINER_ID": "%d" % t.rank,
"PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint,
"PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
"PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
}
current_env.update(proc_env)
logger.debug("trainer proc env:{}".format(current_env))
# cmd = [sys.executable, "-u", training_script]
logger.info("start trainer proc:{} env:{}".format(cmd, proc_env))
fn = None
if log_dir is not None:
os.system("mkdir -p {}".format(log_dir))
fn = open("%s/workerlog.%d" % (log_dir, idx), "a")
proc = subprocess.Popen(cmd, env=current_env, stdout=fn, stderr=fn)
else:
proc = subprocess.Popen(cmd, env=current_env)
tp = TrainerProc()
tp.proc = proc
tp.rank = t.rank
tp.local_rank = idx
tp.log_fn = fn
tp.log_offset = fn.tell() if fn else None
tp.cmd = cmd
procs.append(proc)
return procs
......@@ -19,9 +19,14 @@ import copy
import os
import sys
import subprocess
import logging
from paddlerec.core.engine.engine import Engine
from paddlerec.core.utils import envs
import paddlerec.core.engine.cluster_utils as cluster_utils
logger = logging.getLogger("root")
logger.propagate = False
class LocalClusterEngine(Engine):
......@@ -97,42 +102,70 @@ class LocalClusterEngine(Engine):
stderr=fn,
cwd=os.getcwd())
procs.append(proc)
elif fleet_mode.upper() == "COLLECTIVE":
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES")
if cuda_visible_devices is None or cuda_visible_devices == "":
selected_gpus = [
x.strip() for x in self.envs["selected_gpus"].split(",")
]
else:
# change selected_gpus into relative values
# e.g. CUDA_VISIBLE_DEVICES=4,5,6,7; args.selected_gpus=4,5,6,7;
# therefore selected_gpus=0,1,2,3
cuda_visible_devices_list = cuda_visible_devices.split(',')
for x in self.envs["selected_gpus"].split(","):
    assert x.strip() in cuda_visible_devices_list, "Can't find "\
        "your selected_gpus %s in CUDA_VISIBLE_DEVICES[%s]."\
        % (x, cuda_visible_devices)
selected_gpus = [
cuda_visible_devices_list.index(x.strip())
for x in self.envs["selected_gpus"].split(",")
]
selected_gpus_num = len(selected_gpus)
factory = "paddlerec.core.factory"
cmd = [sys.executable, "-u", "-m", factory, self.trainer]
print("use_paddlecloud_flag:{}".format(
cluster_utils.use_paddlecloud()))
if cluster_utils.use_paddlecloud():
cluster, pod = cluster_utils.get_cloud_cluster(selected_gpus)
logger.info("get cluster from cloud:{}".format(cluster))
procs = cluster_utils.start_local_trainers(
cluster, pod, cmd, log_dir=logs_dir)
else:
# trainers_num = 1 or not use paddlecloud ips="a,b"
for i in range(selected_gpus_num - 1):
while True:
new_port = envs.find_free_port()
if new_port not in ports:
ports.append(new_port)
break
user_endpoints = ",".join(
["127.0.0.1:" + str(x) for x in ports])
for i in range(selected_gpus_num):
current_env.update({
"PADDLE_TRAINER_ENDPOINTS": user_endpoints,
"PADDLE_CURRENT_ENDPOINTS": user_endpoints[i],
"PADDLE_TRAINERS_NUM": str(worker_num),
"TRAINING_ROLE": "TRAINER",
"PADDLE_TRAINER_ID": str(i),
"FLAGS_selected_gpus": str(selected_gpus[i]),
"PADDLEREC_GPU_NUMS": str(selected_gpus_num)
})
os.system("mkdir -p {}".format(logs_dir))
fn = open("%s/worker.%d" % (logs_dir, i), "w")
log_fns.append(fn)
proc = subprocess.Popen(
cmd,
env=current_env,
stdout=fn,
stderr=fn,
cwd=os.getcwd())
procs.append(proc)
# only wait worker to finish here
for i, proc in enumerate(procs):
......
......@@ -49,10 +49,12 @@ runner:
save_checkpoint_path: "increment"
save_inference_path: "inference"
print_interval: 1
phases: [train]
- name: infer_runner
class: infer
init_model_path: "increment/1"
device: cpu
phases: [infer]
phase:
- name: train
......
......@@ -30,13 +30,12 @@
### One-click download of the training and test data
```bash
sh download_data.sh
sh run.sh
```
Run this script from the models/rank/dnn/data directory. It downloads the Criteo dataset from a mirror server in China and extracts it into the designated folders: the raw full training data goes to `./train_data_full/`, the raw full test data to `./test_data_full/`, and the raw training and test data used for quick verification to `./train_data/` and `./test_data/`. The preprocessed full training data is placed in `./slot_train_data_full/`, the preprocessed full test data in `./slot_test_data_full/`, and the preprocessed quick-verification training and test data in `./slot_train_data/` and `./slot_test_data/`.
The expected output of the script is:
```bash
> sh download_data.sh
--2019-11-26 06:31:33-- https://fleet.bj.bcebos.com/ctr_data.tar.gz
Resolving fleet.bj.bcebos.com... 10.180.112.31
Connecting to fleet.bj.bcebos.com|10.180.112.31|:443... connected.
......@@ -100,7 +99,7 @@ def get_dataset(inputs, args)
3. Create a subclass that inherits from one of the dataset base classes. If your data mixes several types and needs to be converted to numeric values during preprocessing, `MultiSlotDataGenerator` is recommended; if preprocessing is already done and the data is saved to files, it can be read directly as strings with `MultiSlotStringDataGenerator`, which is faster. In the example code we subclass `MultiSlotDataGenerator` with a dataset class named `CriteoDataset`.
4. Inherit and implement the base class's `generate_sample` function to read the data line by line. This function should return an iterable reader method (a function containing `yield` is no longer an ordinary function but a generator, an iterable object equivalent to an array, linked list, file, string, etc.).
5. Inside this iterable function, such as `def reader()` in the example code, we define the data-reading logic, e.g. slicing, converting, and preprocessing each line of data.
6. Finally, the data must be arranged in a specific format so that the dataset can read it correctly and feed it into the training network. In short, the output order of the data must correspond strictly one-to-one with the `inputs` created in the network. In the example code we arrange the data in the format `click:value dense_feature:value ... dense_feature:value 1:value ... 26:value`. We use print because run.sh redirects the output into files such as slot_train_data, which the model then reads directly. For custom use, you can use `zip` to build a list of (name, value) tuples and yield it, then point the data_converter parameter in config.yaml at your reader (a sketch of this variant follows the example code below).
```python
......@@ -113,11 +112,22 @@ hash_dim_ = 1000001
continuous_range_ = range(1, 14)
categorical_range_ = range(14, 40)
class CriteoDataset(dg.MultiSlotDataGenerator):
"""
CriteoDataset: inherits MultiSlotDataGenerator and implements data reading
Help document: http://wiki.baidu.com/pages/viewpage.action?pageId=728820675
"""
def generate_sample(self, line):
"""
Read the data line by line and process it as a dictionary
"""
def reader():
"""
This function needs to be implemented by the user, based on data format
"""
features = line.rstrip('\n').split('\t')
dense_feature = []
sparse_feature = []
......@@ -137,11 +147,16 @@ class CriteoDataset(dg.MultiSlotDataGenerator):
for idx in categorical_range_:
feature_name.append("C" + str(idx - 13))
feature_name.append("label")
s = "click:" + str(label[0])
for i in dense_feature:
s += " dense_feature:" + str(i)
for i in range(1, 1 + len(categorical_range_)):
s += " " + str(i) + ":" + str(sparse_feature[i - 1][0])
print(s.strip()) # add print for data preprocessing
return reader
d = CriteoDataset()
d.run_from_stdin()
```
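For the yield-based variant mentioned in step 6, used when the data_converter parameter in config.yaml points at your reader instead of printing and redirecting, a minimal sketch of how `reader()` could end is shown below; the feature names are illustrative and must match the network's `inputs`:

```python
# Hypothetical yield-based ending for reader(), following the zip()
# approach described in step 6; it replaces the print(...) call above.
feature_name = ["dense_feature"] + ["C" + str(i) for i in range(1, 27)] + ["label"]
yield list(zip(feature_name, [dense_feature] + sparse_feature + [label]))
```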
......@@ -149,117 +164,124 @@ d.run_from_stdin()
We can verify the Dataset's output independently of the network to check that it matches expectations. Use the command
`cat <data file> | python <dataset reader .py>` to debug the dataset code:
```bash
cat train_data/part-0 | python get_slot_data.py
```
The output data format is as follows:
`label:value dense_input:value ... dense_input:value sparse_input:value ... sparse_input:value `
The expected output looks like the following (excerpt):
```bash
...
click:0 dense_feature:0.05 dense_feature:0.00663349917081 dense_feature:0.05 dense_feature:0.0 dense_feature:0.02159375 dense_feature:0.008 dense_feature:0.15 dense_feature:0.04 dense_feature:0.362 dense_feature:0.1 dense_feature:0.2 dense_feature:0.0 dense_feature:0.04 1:715353 2:817085 3:851010 4:833725 5:286835 6:948614 7:881652 8:507110 9:27346 10:646986 11:643076 12:200960 13:18464 14:202774 15:532679 16:729573 17:342789 18:562805 19:880474 20:984402 21:666449 22:26235 23:700326 24:452909 25:884722 26:787527
...
```
## Building the model
### Declaring the data inputs
As introduced in the data-preparation section, the Criteo dataset contains both continuous and discrete (sparse) features, so the CTR-DNN model has three kinds of data inputs: `dense_input` carries the continuous data, with its dimension set by the hyperparameter `dense_input_dim` and its type normalized floating point; `sparse_inputs` records the discrete data, which in Criteo spans 26 slots, so we create 26 sparse inputs named `1~26` of integer type; finally, each sample's `label` indicates whether it was clicked, as an integer where 0 marks a negative sample and 1 a positive one.
Data inputs are declared in Paddle with `paddle.fluid.data()`, which creates placeholders of the specified types; data IO feeds data according to these definitions.
```python
dense_input = fluid.data(name="dense_input",
                         shape=[-1, args.dense_input_dim],
                         dtype="float32")

sparse_inputs = [
    fluid.data(name=str(i),
               shape=[-1, 1],
               lod_level=1,
               dtype="int64") for i in range(1, 27)
]

label = fluid.data(name="label", shape=[-1, 1], dtype="int64")
inputs = [dense_input] + sparse_inputs + [label]
```
### CTR-DNN network
The CTR-DNN network is fairly direct: at heart it is a binary classification task; see `model.py` for the code. The model mainly consists of an `Embedding` layer and `FC` layers, plus the loss and AUC computation for the classification task.
#### Embedding layer
First, how the `Embedding` layer is built: its input is `sparse_input`, and its shape is defined by the hyperparameters `sparse_feature_number` and `sparse_feature_dim`. The `is_sparse` parameter deserves special mention: when we set `is_sparse=True`, the computation graph treats the parameter as sparse, and both backward updates and distributed communication are performed sparsely, which greatly improves efficiency while keeping results identical.
After each sparse input passes through the Embedding layer, the results are collected into a list for a convenient concat operation.
```python
def embedding_layer(input):
    if self.distributed_embedding:
        emb = fluid.contrib.layers.sparse_embedding(
            input=input,
            size=[self.sparse_feature_number, self.sparse_feature_dim],
            param_attr=fluid.ParamAttr(
                name="SparseFeatFactors",
                initializer=fluid.initializer.Uniform()))
    else:
        emb = fluid.layers.embedding(
            input=input,
            is_sparse=True,
            is_distributed=self.is_distributed,
            size=[self.sparse_feature_number, self.sparse_feature_dim],
            param_attr=fluid.ParamAttr(
                name="SparseFeatFactors",
                initializer=fluid.initializer.Uniform()))
    emb_sum = fluid.layers.sequence_pool(input=emb, pool_type='sum')
    return emb_sum

sparse_embed_seq = list(map(embedding_layer, self.sparse_inputs))  # [C1~C26]
```
#### FC layers
The values looked up from the embedding tables for the discrete data are `concat`enated with the continuous-data input into a single tensor, which serves as the raw input of the fully connected layers. We stack four FC layers, whose output sizes are given by the hyperparameter `fc_sizes`; each FC layer is followed by a `relu` activation and initialized with a random normal distribution whose standard deviation is inversely proportional to the square root of the previous layer's output size.
```python
concated = fluid.layers.concat(
    sparse_embed_seq + [self.dense_input], axis=1)

fcs = [concated]
hidden_layers = envs.get_global_env("hyper_parameters.fc_sizes")

for size in hidden_layers:
    output = fluid.layers.fc(
        input=fcs[-1],
        size=size,
        act='relu',
        param_attr=fluid.ParamAttr(
            initializer=fluid.initializer.Normal(
                scale=1.0 / math.sqrt(fcs[-1].shape[1]))))
    fcs.append(output)
```
#### Loss and AUC computation
- The prediction is produced by an FC layer with output size 2 and a softmax activation, which gives the probability of each sample being a positive or negative example.
- The per-sample loss is given by cross entropy; its input has shape [batch_size, 2] with float type, and the label input has shape [batch_size, 1] with int type.
- The batch loss `avg_cost` is the mean of the per-sample losses.
- We also compute the prediction AUC via `fluid.layers.auc()`, which returns three values: the global AUC accumulated from the first batch to the current one, `auc`; the AUC over the most recent batches, `batch_auc`; and the auc_states `_`, which contain the `batch_stat_pos, batch_stat_neg, stat_pos, stat_neg` information. `batch_auc` is averaged over the last 20 batches, as specified by `slide_steps=20`; the ROC curve is discretized with 4096 thresholds, as specified by `num_thresholds=2**12`.
```
predict = fluid.layers.fc(
    input=fcs[-1],
    size=2,
    act="softmax",
    param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
        scale=1 / math.sqrt(fcs[-1].shape[1]))))

self.predict = predict

auc, batch_auc, _ = fluid.layers.auc(input=self.predict,
                                     label=self.label_input,
                                     num_thresholds=2**12,
                                     slide_steps=20)

cost = fluid.layers.cross_entropy(
    input=self.predict, label=self.label_input)
avg_cost = fluid.layers.reduce_mean(cost)
```
After building the network above, training finally yields the two key metrics `BATCH_AUC` and `auc`.
```
PaddleRec: Runner single_cpu_infer Begin
Executor Mode: infer
processor_register begin
Running SingleInstance.
Running SingleNetwork.
Running SingleInferStartup.
Running SingleInferRunner.
load persistables from increment_dnn/3
batch: 20, BATCH_AUC: [0.75670043], AUC: [0.77490453]
batch: 40, BATCH_AUC: [0.77020144], AUC: [0.77490437]
batch: 60, BATCH_AUC: [0.77464683], AUC: [0.77490435]
batch: 80, BATCH_AUC: [0.76858989], AUC: [0.77490416]
batch: 100, BATCH_AUC: [0.75728286], AUC: [0.77490362]
batch: 120, BATCH_AUC: [0.75007016], AUC: [0.77490286]
...
batch: 720, BATCH_AUC: [0.76840144], AUC: [0.77489881]
batch: 740, BATCH_AUC: [0.76659033], AUC: [0.77489854]
batch: 760, BATCH_AUC: [0.77332639], AUC: [0.77489849]
batch: 780, BATCH_AUC: [0.78361653], AUC: [0.77489874]
Infer phase2 of epoch increment_dnn/3 done, use time: 52.7707588673, global metrics: BATCH_AUC=[0.78361653], AUC=[0.77489874]
PaddleRec Finish
```
## Online learning: launching and configuring streaming training
......@@ -387,5 +409,5 @@ auc_var, batch_auc_var, auc_states = fluid.layers.auc(
```
4. Once the data is ready, streaming training can be started following the standard training procedure:
```shell
python -m paddlerec.run -m models/rank/dnn/config.yaml
```
......@@ -61,8 +61,7 @@ class CriteoDataset(dg.MultiSlotDataGenerator):
s += " dense_feature:" + str(i)
for i in range(1, 1 + len(categorical_range_)):
s += " " + str(i) + ":" + str(sparse_feature[i - 1][0])
print(s.strip())  # add print for data preprocessing
return reader
......
# GRU4REC
Below is a brief directory layout and description for this example:
```
├── data # sample data and data-processing scripts
    ├── train
        ├── small_train.txt # sample training data
    ├── test
        ├── small_test.txt # sample test data
    ├── convert_format.py # data format conversion script
    ├── download.py # data download script
    ├── preprocess.py # data preprocessing script
    ├── text2paddle.py # script to generate paddle training data
├── __init__.py
├── README.md # documentation
├── model.py # model definition
├── config.yaml # configuration
├── data_prepare.sh # one-click data-processing script
├── rsc15_reader.py # reader
```
Note: before reading this example, we recommend first getting familiar with:
[the PaddleRec beginner tutorial](https://github.com/PaddlePaddle/PaddleRec/blob/master/README.md)
---
## Contents
- [Model overview](#model-overview)
- [Data preparation](#data-preparation)
- [Environment](#environment)
- [Quick start](#quick-start)
- [Reproducing the paper](#reproducing-the-paper)
- [Advanced usage](#advanced-usage)
- [FAQ](#faq)
## Model overview
The GRU4REC model is introduced in the paper [Session-based Recommendations with Recurrent Neural Networks](https://arxiv.org/abs/1511.06939).
The paper's contribution is the first application of RNNs (GRU) to session-based recommendation, with clear gains over traditional KNN and matrix-factorization approaches.
Its core idea is to treat the series of items a user clicks within one session as a sequence for training the RNN. At prediction time, given a known click sequence as input, the model predicts the item most likely to be clicked next.
Session-based recommendation applies to a wide range of scenarios, such as product browsing, news clicks, and location check-in sequences.
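To make this concrete, below is a minimal sketch of a GRU4REC-style forward pass in fluid 1.x; the layer sizes and variable names are illustrative assumptions, and the real network lives in this example's model.py:

```python
import paddle.fluid as fluid

vocab_size, hid_size = 37483, 100  # sizes assumed from config.yaml

# Each sample: a session's clicks (src) and the same sequence shifted by one (dst).
src = fluid.data(name="src_wordseq", shape=[-1, 1], dtype="int64", lod_level=1)
dst = fluid.data(name="dst_wordseq", shape=[-1, 1], dtype="int64", lod_level=1)

emb = fluid.layers.embedding(
    input=src, size=[vocab_size, hid_size], is_sparse=True)
fc0 = fluid.layers.fc(input=emb, size=hid_size * 3)       # GRU gate pre-activations
gru = fluid.layers.dynamic_gru(input=fc0, size=hid_size)  # session-level recurrence
predict = fluid.layers.fc(input=gru, size=vocab_size, act="softmax")

# At every step, predict the next clicked item.
cost = fluid.layers.cross_entropy(input=predict, label=dst)
avg_cost = fluid.layers.mean(x=cost)
```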
This model is configured to use the demo dataset by default; for accuracy verification, see the [Reproducing the paper](#reproducing-the-paper) section.
Functionality supported by this project:
Training: single-machine CPU, single-machine single-GPU, locally simulated parameter server, and incremental training; for configuration see [launching training](https://github.com/PaddlePaddle/PaddleRec/blob/master/doc/train.md)
Inference: single-machine CPU and single-machine single-GPU; for configuration see [PaddleRec offline inference](https://github.com/PaddlePaddle/PaddleRec/blob/master/doc/predict.md)
## Data preparation
Data processing in this example consists of three steps:
- Step 1: download the original dataset.
```
cd data/
python download.py
```
- Step 2: data preprocessing and format conversion (the filtering logic is sketched after the commands below).
  1. Merge the original dataset by session_id to obtain each session's date and its ordered click list.
  2. Filter out sessions of length 1, and items clicked fewer than 5 times.
  3. Split into training and test sets: sessions from the final day of the dataset form the test set, and all earlier data forms the training set.
```
python preprocess.py
python convert_format.py
```
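The filtering in step 2 comes down to a few pandas group-by operations. A minimal sketch, assuming `data` is the loaded clicks DataFrame (the full script is preprocess.py, reproduced later in this commit):

```python
import numpy as np

# Keep sessions with more than one click, then items clicked at least 5 times.
session_lengths = data.groupby('session_id').size()
data = data[np.in1d(data.session_id,
                    session_lengths[session_lengths > 1].index)]
item_supports = data.groupby('item_id').size()
data = data[np.in1d(data.item_id,
                    item_supports[item_supports >= 5].index)]
```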
After this step, two files appear under data/: rsc15_train_tr_paddle.txt is the raw training file and rsc15_test_paddle.txt the raw test file, with the following format:
```
214536502 214536500 214536506 214577561
214662742 214662742 214825110 214757390 214757407 214551617
214716935 214774687 214832672
214836765 214706482
214701242 214826623
214826835 214826715
214838855 214838855
214576500 214576500 214576500
214821275 214821275 214821371 214821371 214821371 214717089 214563337 214706462 214717436 214743335 214826837 214819762
214717867 21471786
```
- Step 3: generate the dictionary and organize the data paths. This step builds a vocabulary from the training and test files (its core is sketched after the commands below), generates the corresponding paddle input files, and gathers the training files under data/all_train and the test files under data/all_test.
```
mkdir raw_train_data && mkdir raw_test_data
mv rsc15_train_tr_paddle.txt raw_train_data/ && mv rsc15_test_paddle.txt raw_test_data/
mkdir all_train && mkdir all_test
python text2paddle.py raw_train_data/ raw_test_data/ all_train all_test vocab.txt
```
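The core of the vocabulary build in text2paddle.py (shown in full later in this commit) is a frequency-sorted, zero-based id assignment. A minimal sketch, assuming `word_freq` is a word-to-count dict:

```python
# Sort by descending frequency (ties broken by word) and assign ids 0..N-1.
word_freq_sorted = sorted(word_freq.items(), key=lambda x: (-x[1], x[0]))
word_idx = {w: i for i, (w, _) in enumerate(word_freq_sorted)}
```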
For convenience, we provide a one-click data-preparation script:
```
sh data_prepare.sh
```
## Environment
PaddlePaddle>=1.7.2
python 2.7/3.5/3.6/3.7
PaddleRec >=0.1
os : windows/linux/macos
## Quick start
### Single-machine training
Set the device, number of epochs, etc. in config.yaml:
```
runner:
- name: cpu_train_runner
class: train
device: cpu # gpu
epochs: 10
save_checkpoint_interval: 1
save_inference_interval: 1
save_checkpoint_path: "increment_gru4rec"
save_inference_path: "inference_gru4rec"
save_inference_feed_varnames: ["src_wordseq", "dst_wordseq"] # feed vars of save inference
save_inference_fetch_varnames: ["mean_0.tmp_0", "top_k_0.tmp_0"]
print_interval: 10
phases: [train]
```
### Single-machine inference
Set the device, the initialization model path, etc. in config.yaml:
```
- name: cpu_infer_runner
class: infer
init_model_path: "increment_gru4rec"
device: cpu # gpu
phases: [infer]
```
### Run
```
python -m paddlerec.run -m paddlerec.models.recall.gru4rec
```
### Results
Training output on the sample data:
```
Running SingleStartup.
Running SingleRunner.
2020-09-22 03:31:18,167-INFO: [Train], epoch: 0, batch: 10, time_each_interval: 4.34s, RecallCnt: [1669.], cost: [8.366313], InsCnt: [16228.], Acc(Recall@20): [0.10284693]
2020-09-22 03:31:21,982-INFO: [Train], epoch: 0, batch: 20, time_each_interval: 3.82s, RecallCnt: [3168.], cost: [8.170701], InsCnt: [31943.], Acc(Recall@20): [0.09917666]
2020-09-22 03:31:25,797-INFO: [Train], epoch: 0, batch: 30, time_each_interval: 3.81s, RecallCnt: [4855.], cost: [8.017181], InsCnt: [47892.], Acc(Recall@20): [0.10137393]
...
epoch 0 done, use time: 6003.78719687, global metrics: cost=[4.4394927], InsCnt=23622448.0 RecallCnt=14547467.0 Acc(Recall@20)=0.6158323218660487
2020-09-22 05:11:17,761-INFO: save epoch_id:0 model into: "inference_gru4rec/0"
...
epoch 9 done, use time: 6009.97707605, global metrics: cost=[4.069373], InsCnt=236237470.0 RecallCnt=162838200.0 Acc(Recall@20)=0.6892988086157644
2020-09-22 20:17:11,358-INFO: save epoch_id:9 model into: "inference_gru4rec/9"
PaddleRec Finish
```
Inference output on the sample data:
```
Running SingleInferStartup.
Running SingleInferRunner.
load persistables from increment_gru4rec/9
2020-09-23 03:46:21,081-INFO: [Infer] batch: 20, time_each_interval: 3.68s, RecallCnt: [24875.], InsCnt: [35581.], Acc(Recall@20): [0.6991091]
Infer infer of epoch 9 done, use time: 5.25408315659, global metrics: InsCnt=52551.0 RecallCnt=36720.0 Acc(Recall@20)=0.698749785922247
...
Infer infer of epoch 0 done, use time: 5.20699501038, global metrics: InsCnt=52551.0 RecallCnt=33664.0 Acc(Recall@20)=0.6405967536298073
PaddleRec Finish
```
## Reproducing the paper
To reproduce the paper's results with the full original dataset, modify the hyperparameters in config.yaml:
- batch_size: set batch_size of the dataset_train dataset in config.yaml to 500.
- epochs: set epochs of the runner in config.yaml to 10.
- data source: set data_path of the dataset_train dataset in config.yaml to "{workspace}/data/all_train", and data_path of the dataset_test dataset to "{workspace}/data/all_test".
Test results after training 10 epochs on GPU:
epoch | test recall@20 | time (s)
-- | -- | --
1 | 0.6406 | 6003
2 | 0.6727 | 6007
3 | 0.6831 | 6108
4 | 0.6885 | 6025
5 | 0.6913 | 6019
6 | 0.6931 | 6011
7 | 0.6952 | 6015
8 | 0.6968 | 6076
9 | 0.6972 | 6076
10 | 0.6987 | 6009
To run with these modifications, set 'workspace' in config.yaml to the directory containing config.yaml, then execute:
```
python -m paddlerec.run -m /home/your/dir/config.yaml # debug mode: pass the absolute path of the local config directly
```
## Advanced usage
## FAQ
......@@ -16,18 +16,19 @@ workspace: "models/recall/gru4rec"
dataset:
- name: dataset_train
batch_size: 500
type: DataLoader # QueueDataset
data_path: "{workspace}/data/train"
data_converter: "{workspace}/rsc15_reader.py"
- name: dataset_infer
batch_size: 500
type: DataLoader #QueueDataset
data_path: "{workspace}/data/test"
data_converter: "{workspace}/rsc15_reader.py"
hyper_parameters:
recall_k: 20
vocab_size: 37483
hid_size: 100
emb_lr_x: 10.0
gru_lr_x: 1.0
......@@ -40,30 +41,34 @@ hyper_parameters:
strategy: async
mode: [cpu_train_runner, cpu_infer_runner]
runner:
- name: cpu_train_runner
class: train
device: cpu
epochs: 10
save_checkpoint_interval: 1
save_inference_interval: 1
save_checkpoint_path: "increment_gru4rec"
save_inference_path: "inference_gru4rec"
save_inference_feed_varnames: ["src_wordseq", "dst_wordseq"] # feed vars of save inference
save_inference_fetch_varnames: ["mean_0.tmp_0", "top_k_0.tmp_0"]
print_interval: 10
phases: [train]
- name: cpu_infer_runner
class: infer
init_model_path: "increment/0"
init_model_path: "increment_gru4rec"
device: cpu
phases: [infer]
phase:
- name: train
model: "{workspace}/model.py"
dataset_name: dataset_train
thread_num: 1
- name: infer
model: "{workspace}/model.py"
dataset_name: dataset_infer
thread_num: 1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import codecs
def convert_format(input, output):
with codecs.open(input, "r", encoding='utf-8') as rf:
with codecs.open(output, "w", encoding='utf-8') as wf:
last_sess = -1
sign = 1
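# 'sign' marks the very first session, so its first item is written without a leading newline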
i = 0
for l in rf:
i = i + 1
if i == 1:
continue
if (i % 1000000 == 1):
print(i)
tokens = l.strip().split()
if (int(tokens[0]) != last_sess):
if (sign):
sign = 0
wf.write(tokens[1] + " ")
else:
wf.write("\n" + tokens[1] + " ")
last_sess = int(tokens[0])
else:
wf.write(tokens[1] + " ")
input = "rsc15_train_tr.txt"
output = "rsc15_train_tr_paddle.txt"
input2 = "rsc15_test.txt"
output2 = "rsc15_test_paddle.txt"
convert_format(input, output)
convert_format(input2, output2)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import requests
import shutil
import sys
import time
import os
lasttime = time.time()
FLUSH_INTERVAL = 0.1
def progress(str, end=False):
global lasttime
if end:
str += "\n"
lasttime = 0
if time.time() - lasttime >= FLUSH_INTERVAL:
sys.stdout.write("\r%s" % str)
lasttime = time.time()
sys.stdout.flush()
def _download_file(url, savepath, print_progress):
r = requests.get(url, stream=True)
total_length = r.headers.get('content-length')
if total_length is None:
with open(savepath, 'wb') as f:
shutil.copyfileobj(r.raw, f)
else:
with open(savepath, 'wb') as f:
dl = 0
total_length = int(total_length)
starttime = time.time()
if print_progress:
print("Downloading %s" % os.path.basename(savepath))
for data in r.iter_content(chunk_size=4096):
dl += len(data)
f.write(data)
if print_progress:
done = int(50 * dl / total_length)
progress("[%-50s] %.2f%%" %
('=' * done, float(100 * dl) / total_length))
if print_progress:
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
_download_file("https://paddlerec.bj.bcebos.com/gnn%2Fyoochoose-clicks.dat",
"./yoochoose-clicks.dat", True)
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 25 16:20:12 2015
@author: Balázs Hidasi
"""
import numpy as np
import pandas as pd
import datetime as dt
import time
PATH_TO_ORIGINAL_DATA = './'
PATH_TO_PROCESSED_DATA = './'
data = pd.read_csv(
PATH_TO_ORIGINAL_DATA + 'yoochoose-clicks.dat',
sep=',',
header=0,
usecols=[0, 1, 2],
dtype={0: np.int32,
1: str,
2: np.int64})
data.columns = ['session_id', 'timestamp', 'item_id']
data['Time'] = data.timestamp.apply(lambda x: time.mktime(dt.datetime.strptime(x, '%Y-%m-%dT%H:%M:%S.%fZ').timetuple())) #This is not UTC. It does not really matter.
del (data['timestamp'])
session_lengths = data.groupby('session_id').size()
data = data[np.in1d(data.session_id, session_lengths[session_lengths > 1]
.index)]
item_supports = data.groupby('item_id').size()
data = data[np.in1d(data.item_id, item_supports[item_supports >= 5].index)]
session_lengths = data.groupby('session_id').size()
data = data[np.in1d(data.session_id, session_lengths[session_lengths >= 2]
.index)]
tmax = data.Time.max()
session_max_times = data.groupby('session_id').Time.max()
session_train = session_max_times[session_max_times < tmax - 86400].index
session_test = session_max_times[session_max_times >= tmax - 86400].index
train = data[np.in1d(data.session_id, session_train)]
test = data[np.in1d(data.session_id, session_test)]
test = test[np.in1d(test.item_id, train.item_id)]
tslength = test.groupby('session_id').size()
test = test[np.in1d(test.session_id, tslength[tslength >= 2].index)]
print('Full train set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(
len(train), train.session_id.nunique(), train.item_id.nunique()))
train.to_csv(
PATH_TO_PROCESSED_DATA + 'rsc15_train_full.txt', sep='\t', index=False)
print('Test set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(
len(test), test.session_id.nunique(), test.item_id.nunique()))
test.to_csv(PATH_TO_PROCESSED_DATA + 'rsc15_test.txt', sep='\t', index=False)
tmax = train.Time.max()
session_max_times = train.groupby('session_id').Time.max()
session_train = session_max_times[session_max_times < tmax - 86400].index
session_valid = session_max_times[session_max_times >= tmax - 86400].index
train_tr = train[np.in1d(train.session_id, session_train)]
valid = train[np.in1d(train.session_id, session_valid)]
valid = valid[np.in1d(valid.item_id, train_tr.item_id)]
tslength = valid.groupby('session_id').size()
valid = valid[np.in1d(valid.session_id, tslength[tslength >= 2].index)]
print('Train set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(
len(train_tr), train_tr.session_id.nunique(), train_tr.item_id.nunique()))
train_tr.to_csv(
PATH_TO_PROCESSED_DATA + 'rsc15_train_tr.txt', sep='\t', index=False)
print('Validation set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format(
len(valid), valid.session_id.nunique(), valid.item_id.nunique()))
valid.to_csv(
PATH_TO_PROCESSED_DATA + 'rsc15_train_valid.txt', sep='\t', index=False)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import six
import collections
import os
import sys
import io
if six.PY2:
reload(sys)
sys.setdefaultencoding('utf-8')
def word_count(input_file, word_freq=None):
"""
compute word count from corpus
"""
if word_freq is None:
word_freq = collections.defaultdict(int)
for l in input_file:
for w in l.strip().split():
word_freq[w] += 1
return word_freq
def build_dict(min_word_freq=0, train_dir="", test_dir=""):
"""
Build a word dictionary from the corpus, Keys of the dictionary are words,
and values are zero-based IDs of these words.
"""
word_freq = collections.defaultdict(int)
files = os.listdir(train_dir)
for fi in files:
with io.open(os.path.join(train_dir, fi), "r") as f:
word_freq = word_count(f, word_freq)
files = os.listdir(test_dir)
for fi in files:
with io.open(os.path.join(test_dir, fi), "r") as f:
word_freq = word_count(f, word_freq)
word_freq = [x for x in six.iteritems(word_freq) if x[1] > min_word_freq]
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*word_freq_sorted))
word_idx = dict(list(zip(words, six.moves.range(len(words)))))
return word_idx
def write_paddle(word_idx, train_dir, test_dir, output_train_dir,
output_test_dir):
files = os.listdir(train_dir)
if not os.path.exists(output_train_dir):
os.mkdir(output_train_dir)
for fi in files:
with io.open(os.path.join(train_dir, fi), "r") as f:
with io.open(os.path.join(output_train_dir, fi), "w") as wf:
for l in f:
l = l.strip().split()
l = [word_idx.get(w) for w in l]
for w in l:
wf.write(str2file(str(w) + " "))
wf.write(str2file("\n"))
files = os.listdir(test_dir)
if not os.path.exists(output_test_dir):
os.mkdir(output_test_dir)
for fi in files:
with io.open(os.path.join(test_dir, fi), "r", encoding='utf-8') as f:
with io.open(
os.path.join(output_test_dir, fi), "w",
encoding='utf-8') as wf:
for l in f:
l = l.strip().split()
l = [word_idx.get(w) for w in l]
for w in l:
wf.write(str2file(str(w) + " "))
wf.write(str2file("\n"))
def str2file(str):
if six.PY2:
return str.decode("utf-8")
else:
return str
def text2paddle(train_dir, test_dir, output_train_dir, output_test_dir,
output_vocab):
vocab = build_dict(0, train_dir, test_dir)
print("vocab size:", str(len(vocab)))
with io.open(output_vocab, "w", encoding='utf-8') as wf:
wf.write(str2file(str(len(vocab)) + "\n"))
write_paddle(vocab, train_dir, test_dir, output_train_dir, output_test_dir)
train_dir = sys.argv[1]
test_dir = sys.argv[2]
output_train_dir = sys.argv[3]
output_test_dir = sys.argv[4]
output_vocab = sys.argv[5]
text2paddle(train_dir, test_dir, output_train_dir, output_test_dir,
output_vocab)
#! /bin/bash
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
echo "begin to download data"
cd data && python download.py
python preprocess.py
echo "begin to convert data (binary -> txt)"
python convert_format.py
mkdir raw_train_data && mkdir raw_test_data
mv rsc15_train_tr_paddle.txt raw_train_data/ && mv rsc15_test_paddle.txt raw_test_data/
mkdir all_train && mkdir all_test
python text2paddle.py raw_train_data/ raw_test_data/ all_train all_test vocab.txt
......@@ -16,6 +16,7 @@ import paddle.fluid as fluid
from paddlerec.core.utils import envs
from paddlerec.core.model import ModelBase
from paddlerec.core.metrics import RecallK
class Model(ModelBase):
......@@ -81,13 +82,13 @@ class Model(ModelBase):
high=self.init_high_bound),
learning_rate=self.fc_lr_x))
cost = fluid.layers.cross_entropy(input=fc, label=dst_wordseq)
acc = RecallK(input=fc, label=dst_wordseq, k=self.recall_k)
if is_infer:
self._infer_results['Recall@20'] = acc
return
avg_cost = fluid.layers.mean(x=cost)
self._cost = avg_cost
self._metrics["cost"] = avg_cost
self._metrics["acc"] = acc
self._metrics["Recall@20"] = acc
......@@ -348,6 +348,7 @@ def cluster_engine(args):
cluster_envs["fleet_mode"] = fleet_mode
cluster_envs["engine_role"] = "WORKER"
cluster_envs["log_dir"] = "logs"
cluster_envs["train.trainer.trainer"] = trainer
cluster_envs["train.trainer.engine"] = "cluster"
cluster_envs["train.trainer.executor_mode"] = executor_mode
......