未验证 提交 9ecd106c 编写于 作者: T tangwei12 提交者: GitHub

Merge branch 'master' into readme_wenhui

......@@ -23,6 +23,8 @@ before_install:
- sudo pip install pylint pytest astroid isort pre-commit
- sudo pip install kiwisolver
- sudo pip install paddlepaddle==1.7.2 --ignore-installed urllib3
- sudo pip uninstall -y rarfile
- sudo pip install rarfile==3.0
- sudo python setup.py install
- |
function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; }
......
......@@ -59,6 +59,7 @@ function _gen_mpi_config() {
-e "s#<$ OUTPUT_PATH $>#$OUTPUT_PATH#g" \
-e "s#<$ THIRDPARTY_PATH $>#$THIRDPARTY_PATH#g" \
-e "s#<$ CPU_NUM $>#$max_thread_num#g" \
-e "s#<$ USE_PYTHON3 $>#$USE_PYTHON3#g" \
-e "s#<$ FLAGS_communicator_is_sgd_optimizer $>#$FLAGS_communicator_is_sgd_optimizer#g" \
-e "s#<$ FLAGS_communicator_send_queue_size $>#$FLAGS_communicator_send_queue_size#g" \
-e "s#<$ FLAGS_communicator_thread_pool_size $>#$FLAGS_communicator_thread_pool_size#g" \
......@@ -76,6 +77,7 @@ function _gen_k8s_config() {
-e "s#<$ AFS_REMOTE_MOUNT_POINT $>#$AFS_REMOTE_MOUNT_POINT#g" \
-e "s#<$ OUTPUT_PATH $>#$OUTPUT_PATH#g" \
-e "s#<$ CPU_NUM $>#$max_thread_num#g" \
-e "s#<$ USE_PYTHON3 $>#$USE_PYTHON3#g" \
-e "s#<$ FLAGS_communicator_is_sgd_optimizer $>#$FLAGS_communicator_is_sgd_optimizer#g" \
-e "s#<$ FLAGS_communicator_send_queue_size $>#$FLAGS_communicator_send_queue_size#g" \
-e "s#<$ FLAGS_communicator_thread_pool_size $>#$FLAGS_communicator_thread_pool_size#g" \
......
......@@ -19,6 +19,7 @@ afs_local_mount_point="/root/paddlejob/workspace/env_run/afs/"
# 新k8s afs挂载帮助文档: http://wiki.baidu.com/pages/viewpage.action?pageId=906443193
PADDLE_PADDLEREC_ROLE=WORKER
use_python3=<$ USE_PYTHON3 $>
CPU_NUM=<$ CPU_NUM $>
GLOG_v=0
......
......@@ -17,6 +17,7 @@ output_path=<$ OUTPUT_PATH $>
thirdparty_path=<$ THIRDPARTY_PATH $>
PADDLE_PADDLEREC_ROLE=WORKER
use_python3=<$ USE_PYTHON3 $>
CPU_NUM=<$ CPU_NUM $>
GLOG_v=0
......
......@@ -159,23 +159,30 @@ class ClusterEnvBase(object):
self.cluster_env["PADDLE_VERSION"] = self.backend_env.get(
"config.paddle_version", "1.7.2")
# python_version
self.cluster_env["USE_PYTHON3"] = self.backend_env.get(
"config.use_python3", "0")
# communicator
max_thread_num = int(envs.get_runtime_environ("max_thread_num"))
self.cluster_env[
"FLAGS_communicator_is_sgd_optimizer"] = self.backend_env.get(
"config.communicator.FLAGS_communicator_is_sgd_optimizer", 0)
self.cluster_env[
"FLAGS_communicator_send_queue_size"] = self.backend_env.get(
"config.communicator.FLAGS_communicator_send_queue_size", 5)
"config.communicator.FLAGS_communicator_send_queue_size",
max_thread_num)
self.cluster_env[
"FLAGS_communicator_thread_pool_size"] = self.backend_env.get(
"config.communicator.FLAGS_communicator_thread_pool_size", 32)
self.cluster_env[
"FLAGS_communicator_max_merge_var_num"] = self.backend_env.get(
"config.communicator.FLAGS_communicator_max_merge_var_num", 5)
"config.communicator.FLAGS_communicator_max_merge_var_num",
max_thread_num)
self.cluster_env[
"FLAGS_communicator_max_send_grad_num_before_recv"] = self.backend_env.get(
"config.communicator.FLAGS_communicator_max_send_grad_num_before_recv",
5)
max_thread_num)
self.cluster_env["FLAGS_communicator_fake_rpc"] = self.backend_env.get(
"config.communicator.FLAGS_communicator_fake_rpc", 0)
self.cluster_env["FLAGS_rpc_retry_times"] = self.backend_env.get(
......
......@@ -69,6 +69,12 @@ dataset:
data_path: "{workspace}/data/sample_data/train"
sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
dense_slots: "dense_var:13"
phase:
- name: phase1
model: "{workspace}/model.py"
dataset_name: dataloader_train
thread_num: 1
```
分布式的训练配置可以改为:
......@@ -101,6 +107,13 @@ dataset:
data_path: "{workspace}/train_data"
sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
dense_slots: "dense_var:13"
phase:
- name: phase1
model: "{workspace}/model.py"
dataset_name: dataloader_train
# 分布式训练节点的CPU_NUM环境变量与thread_num相等,多个phase时,取最大的thread_num
thread_num: 1
```
除此之外,还需关注数据及模型加载的路径,一般而言:
......@@ -120,6 +133,8 @@ cluster_type: mpi # k8s 可选
config:
# 填写任务运行的paddle官方版本号 >= 1.7.2, 默认1.7.2
paddle_version: "1.7.2"
# 是否使用PaddleCloud运行环境下的Python3,默认使用python2
use_python3: 1
# hdfs/afs的配置信息填写
fs_name: "afs://xxx.com"
......@@ -140,11 +155,13 @@ config:
# paddle参数服务器分布式底层超参,无特殊需求不理不改
communicator:
# 使用SGD优化器时,建议设置为1
FLAGS_communicator_is_sgd_optimizer: 0
# 以下三个变量默认都等于训练时的线程数:CPU_NUM
FLAGS_communicator_send_queue_size: 5
FLAGS_communicator_thread_pool_size: 32
FLAGS_communicator_max_merge_var_num: 5
FLAGS_communicator_max_send_grad_num_before_recv: 5
FLAGS_communicator_thread_pool_size: 32
FLAGS_communicator_fake_rpc: 0
FLAGS_rpc_retry_times: 3
......@@ -175,12 +192,12 @@ submit:
# for k8s gpu
# k8s gpu 模式下,训练节点数,及每个节点上的GPU卡数
k8s_trainers: 2
k8s-cpu-cores: 4
k8s_cpu_cores: 4
k8s_gpu_card: 1
# for k8s ps-cpu
k8s_trainers: 2
k8s-cpu-cores: 4
k8s_cpu_cores: 4
k8s_ps_num: 2
k8s_ps_cores: 4
......@@ -232,7 +249,7 @@ phase:
再新增`backend.yaml`
```yaml
backend: "PaddleCloud"
cluster_type: mpi
cluster_type: mpi # k8s可选
config:
paddle_version: "1.7.2"
......@@ -317,7 +334,7 @@ phase:
```yaml
backend: "PaddleCloud"
cluster_type: k8s # k8s 可选
cluster_type: k8s # mpi 可选
config:
# 填写任务运行的paddle官方版本号 >= 1.7.2, 默认1.7.2
......@@ -357,7 +374,7 @@ submit:
# for k8s gpu
# k8s gpu 模式下,训练节点数,及每个节点上的GPU卡数
k8s_trainers: 2
k8s-cpu-cores: 4
k8s_cpu_cores: 4
k8s_gpu_card: 1
```
......@@ -399,7 +416,7 @@ phase:
再新增`backend.yaml`
```yaml
backend: "PaddleCloud"
cluster_type: k8s # k8s 可选
cluster_type: k8s # mpi 可选
config:
# 填写任务运行的paddle官方版本号 >= 1.7.2, 默认1.7.2
......@@ -439,7 +456,7 @@ submit:
# for k8s gpu
# k8s ps-cpu 模式下,训练节点数,参数服务器节点数,及每个节点上的cpu核心数及内存限制
k8s_trainers: 2
k8s-cpu-cores: 4
k8s_cpu_cores: 4
k8s_ps_num: 2
k8s_ps_cores: 4
```
......
# PaddleRec yaml配置说明
# PaddleRec config.yaml配置说明
## 全局变量
......@@ -12,31 +12,31 @@
## runner变量
| 名称 | 类型 | 取值 | 是否必须 | 作用描述 |
| :---------------------------: | :----------: | :-------------------------------------------: | :------: | :------------------------------------------------------------------: |
| name | string | 任意 | 是 | 指定runner名称 |
| 名称 | 类型 | 取值 | 是否必须 | 作用描述 |
| :---------------------------: | :----------: | :-------------------------------------------------------: | :------: | :------------------------------------------------------------------: |
| name | string | 任意 | 是 | 指定runner名称 |
| class | string | train(默认) / infer / local_cluster_train / cluster_train | 是 | 指定运行runner的类别(单机/分布式, 训练/预测) |
| device | string | cpu(默认) / gpu | 否 | 程序执行设备 |
| fleet_mode | string | ps(默认) / pslib / collective | 否 | 分布式运行模式 |
| selected_gpus | string | "0"(默认) | 否 | 程序运行GPU卡号,若以"0,1"的方式指定多卡,则会默认启用collective模式 |
| worker_num | int | 1(默认) | 否 | 参数服务器模式下worker的数量 |
| server_num | int | 1(默认) | 否 | 参数服务器模式下server的数量 |
| distribute_strategy | string | async(默认)/sync/half_async/geo | 否 | 参数服务器模式下训练模式的选择 |
| epochs | int | >= 1 | 否 | 模型训练迭代轮数 |
| phases | list[string] | 由phase name组成的list | 否 | 当前runner的训练过程列表,顺序执行 |
| init_model_path | string | 路径 | 否 | 初始化模型地址 |
| save_checkpoint_interval | int | >= 1 | 否 | Save参数的轮数间隔 |
| save_checkpoint_path | string | 路径 | 否 | Save参数的地址 |
| save_inference_interval | int | >= 1 | 否 | Save预测模型的轮数间隔 |
| save_inference_path | string | 路径 | 否 | Save预测模型的地址 |
| save_inference_feed_varnames | list[string] | 组网中指定Variable的name | 否 | 预测模型的入口变量name |
| save_inference_fetch_varnames | list[string] | 组网中指定Variable的name | 否 | 预测模型的出口变量name |
| print_interval | int | >= 1 | 否 | 训练指标打印batch间隔 |
| instance_class_path | string | 路径 | 否 | 自定义instance流程实现的地址 |
| network_class_path | string | 路径 | 否 | 自定义network流程实现的地址 |
| startup_class_path | string | 路径 | 否 | 自定义startup流程实现的地址 |
| runner_class_path | string | 路径 | 否 | 自定义runner流程实现的地址 |
| terminal_class_path | string | 路径 | 否 | 自定义terminal流程实现的地址 |
| device | string | cpu(默认) / gpu | 否 | 程序执行设备 |
| fleet_mode | string | ps(默认) / pslib / collective | 否 | 分布式运行模式 |
| selected_gpus | string | "0"(默认) | 否 | 程序运行GPU卡号,若以"0,1"的方式指定多卡,则会默认启用collective模式 |
| worker_num | int | 1(默认) | 否 | 参数服务器模式下worker的数量 |
| server_num | int | 1(默认) | 否 | 参数服务器模式下server的数量 |
| distribute_strategy | string | async(默认)/sync/half_async/geo | 否 | 参数服务器模式下训练模式的选择 |
| epochs | int | >= 1 | 否 | 模型训练迭代轮数 |
| phases | list[string] | 由phase name组成的list | 否 | 当前runner的训练过程列表,顺序执行 |
| init_model_path | string | 路径 | 否 | 初始化模型地址 |
| save_checkpoint_interval | int | >= 1 | 否 | Save参数的轮数间隔 |
| save_checkpoint_path | string | 路径 | 否 | Save参数的地址 |
| save_inference_interval | int | >= 1 | 否 | Save预测模型的轮数间隔 |
| save_inference_path | string | 路径 | 否 | Save预测模型的地址 |
| save_inference_feed_varnames | list[string] | 组网中指定Variable的name | 否 | 预测模型的入口变量name |
| save_inference_fetch_varnames | list[string] | 组网中指定Variable的name | 否 | 预测模型的出口变量name |
| print_interval | int | >= 1 | 否 | 训练指标打印batch间隔 |
| instance_class_path | string | 路径 | 否 | 自定义instance流程实现的地址 |
| network_class_path | string | 路径 | 否 | 自定义network流程实现的地址 |
| startup_class_path | string | 路径 | 否 | 自定义startup流程实现的地址 |
| runner_class_path | string | 路径 | 否 | 自定义runner流程实现的地址 |
| terminal_class_path | string | 路径 | 否 | 自定义terminal流程实现的地址 |
......@@ -70,3 +70,55 @@
| optimizer.learning_rate | float | > 0 | 否 | 指定学习率 |
| reg | float | > 0 | 否 | L2正则化参数,只在SGD下生效 |
| others | / | / | / | 由各个模型组网独立指定 |
# PaddleRec backend.yaml配置说明
## 全局变量
| 名称 | 类型 | 取值 | 是否必须 | 作用描述 |
| :----------: | :----: | :-------------: | :------: | :----------------------------------------------: |
| backend | string | paddlecloud/k8s | 是 | 使用PaddleCloud平台提交,还是在公有云K8S集群提交 |
| cluster_type | string | mpi/k8s | 是 | 指定运行的计算集群: mpi 还是 k8s |
## config
| 名称 | 类型 | 取值 | 是否必须 | 作用描述 |
| :--------------------: | :----: | :-------------------------------------: | :------: | :------------------------------------------------------------------------------------------: |
| paddle_version | string | paddle官方版本号,如1.7.2/1.8.0/1.8.3等 | 否 | 指定运行训练使用的Paddle版本,默认1.7.2 |
| use_python3 | int | 0(默认)/1 | 否 | 指定是否使用python3进行训练 |
| fs_name | string | "afs://xxx.com" | 是 | hdfs/afs集群名称所需配置 |
| fs_ugi | string | "usr,pwd" | 是 | hdfs/afs集群密钥所需配置 |
| output_path | string | "/user/your/path" | 否 | 任务输出的远程目录 |
| train_data_path | string | "/user/your/path" | 是 | mpi集群下指定训练数据路径,paddlecloud会自动将数据分片并下载到工作目录的`./train_data`文件夹 |
| test_data_path | string | "/user/your/path" | 否 | mpi集群下指定测试数据路径,会自动下载到工作目录的`./test_data`文件夹 |
| thirdparty_path | string | "/user/your/path" | 否 | mpi集群下指定thirdparty路径,会自动下载到工作目录的`./thirdparty`文件夹 |
| afs_remote_mount_point | string | "/user/your/path" | 是 | k8s集群下指定远程路径的地址,会挂载到工作目录的`./afs/下` |
### config.communicator
| 名称 | 类型 | 取值 | 是否必须 | 作用描述 |
| :----------------------------------------------: | :---: | :------------: | :------: | :----------------------------------------------------: |
| FLAGS_communicator_is_sgd_optimizer | int | 0(默认)/1 | 否 | 异步分布式训练时的多线程的梯度融合方式是否使用SGD模式 |
| FLAGS_communicator_send_queue_size | int | 线程数(默认) | 否 | 分布式训练时发送队列的大小 |
| FLAGS_communicator_max_merge_var_num | int | 线程数(默认) | 否 | 分布式训练多线程梯度融合时,线程数的配置 |
| FLAGS_communicator_max_send_grad_num_before_recv | int | 线程数(默认) | 否 | 分布式训练使用独立recv参数线程时,与send的步调配置超参 |
| FLAGS_communicator_thread_pool_size | int | 32(默认) | 否 | 分布式训练时,多线程发送参数的线程池大小 |
| FLAGS_communicator_fake_rpc | int | 0(默认)/1 | 否 | 分布式训练时,选择不进行通信 |
| FLAGS_rpc_retry_times | int | 3(默认) | 否 | 分布式训练时,GRPC的失败重试次数 |
## submit
| 名称 | 类型 | 取值 | 是否必须 | 作用描述 |
| :-----------: | :----: | :-------------------------: | :------: | :------------------------------------------------------: |
| ak | string | PaddleCloud平台提供的ak密钥 | 是 | paddlecloud用户配置 |
| sk | string | PaddleCloud平台提供的sk密钥 | 否 | paddlecloud用户配置 |
| priority | string | normal/high/very_high | 否 | 任务优先级 |
| job_name | string | 任意 | 是 | 任务名称 |
| group | string | 计算资源所在组名称 | 是 | 组名称 |
| start_cmd | string | 任意 | 是 | 启动命令,默认`python -m paddlerec.run -m ./config.yaml` |
| files | string | 任意 | 是 | 随任务提交上传的文件,给出相对或绝对路径 |
| nodes | int | >=1(默认1) | 否 | mpi集群下的节点数 |
| k8s_trainers | int | >=1(默认1) | 否 | k8s集群下worker的节点数 |
| k8s_cpu_cores | int | >=1(默认1) | 否 | k8s集群下worker的CPU核数 |
| k8s_gpu_card | int | >=1(默认1) | 否 | k8s集群下worker的GPU卡数 |
| k8s_ps_num | int | >=1(默认1) | 否 | k8s集群下server的节点数 |
| k8s_ps_cores | int | >=1(默认1) | 否 | k8s集群下server的CPU核数 |
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyrigh t(c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
workspace: "paddlerec.models.match.match-pyramid"
dataset:
- name: dataset_train
batch_size: 128
type: DataLoader
data_path: "{workspace}/data/train"
data_converter: "{workspace}/train_reader.py"
- name: dataset_infer
batch_size: 1
type: DataLoader
data_path: "{workspace}/data/test"
data_converter: "{workspace}/test_reader.py"
hyper_parameters:
optimizer:
class: adam
learning_rate: 0.001
strategy: async
emb_path: "./data/embedding.npy"
sentence_left_size: 20
sentence_right_size: 500
vocab_size: 193368
emb_size: 50
kernel_num: 8
hidden_size: 20
hidden_act: "relu"
out_size: 1
channels: 1
conv_filter: [2,10]
conv_act: "relu"
pool_size: [6,50]
pool_stride: [6,50]
pool_type: "max"
pool_padding: "VALID"
mode: [train_runner , infer_runner]
# config of each runner.
# runner is a kind of paddle training class, which wraps the train/infer process.
runner:
- name: train_runner
class: train
# num of epochs
epochs: 2
# device to run training or infer
device: cpu
save_checkpoint_interval: 1 # save model interval of epochs
save_inference_interval: 1 # save inference
save_checkpoint_path: "inference" # save checkpoint path
save_inference_path: "inference" # save inference path
save_inference_feed_varnames: [] # feed vars of save inference
save_inference_fetch_varnames: [] # fetch vars of save inference
init_model_path: "" # load model path
print_interval: 2
phases: phase_train
- name: infer_runner
class: infer
# device to run training or infer
device: cpu
print_interval: 1
init_model_path: "inference/1" # load model path
phases: phase_infer
# runner will run all the phase in each epoch
phase:
- name: phase_train
model: "{workspace}/model.py" # user-defined model
dataset_name: dataset_train # select dataset by name
thread_num: 1
- name: phase_infer
model: "{workspace}/model.py" # user-defined model
dataset_name: dataset_infer # select dataset by name
thread_num: 1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import random
# Read Word Dict and Inverse Word Dict
def read_word_dict(filename):
word_dict = {}
for line in open(filename):
line = line.strip().split()
word_dict[int(line[1])] = line[0]
print('[%s]\n\tWord dict size: %d' % (filename, len(word_dict)))
return word_dict
# Read Embedding File
def read_embedding(filename):
embed = {}
for line in open(filename):
line = line.strip().split()
embed[int(line[0])] = list(map(float, line[1:]))
print('[%s]\n\tEmbedding size: %d' % (filename, len(embed)))
return embed
# Convert Embedding Dict 2 numpy array
def convert_embed_2_numpy(embed_dict, embed=None):
for k in embed_dict:
embed[k] = np.array(embed_dict[k])
print('Generate numpy embed:', embed.shape)
return embed
# Read Data
def read_data(filename):
data = {}
for line in open(filename):
line = line.strip().split()
data[line[0]] = list(map(int, line[2:]))
print('[%s]\n\tData size: %s' % (filename, len(data)))
return data
# Read Relation Data
def read_relation(filename):
data = []
for line in open(filename):
line = line.strip().split()
data.append((int(line[0]), line[1], line[2]))
print('[%s]\n\tInstance size: %s' % (filename, len(data)))
return data
Letor07Path = "./data"
word_dict = read_word_dict(filename=os.path.join(Letor07Path, 'word_dict.txt'))
query_data = read_data(filename=os.path.join(Letor07Path, 'qid_query.txt'))
doc_data = read_data(filename=os.path.join(Letor07Path, 'docid_doc.txt'))
embed_dict = read_embedding(filename=os.path.join(Letor07Path,
'embed_wiki-pdc_d50_norm'))
_PAD_ = len(word_dict) #193367
embed_dict[_PAD_] = np.zeros((50, ), dtype=np.float32)
word_dict[_PAD_] = '[PAD]'
W_init_embed = np.float32(np.random.uniform(-0.02, 0.02, [len(word_dict), 50]))
embedding = convert_embed_2_numpy(embed_dict, embed=W_init_embed)
np.save("embedding.npy", embedding)
batch_size = 64
data1_maxlen = 20
data2_maxlen = 500
embed_size = 50
train_iters = 2500
def make_train():
rel_set = {}
pair_list = []
rel = read_relation(filename=os.path.join(Letor07Path,
'relation.train.fold1.txt'))
for label, d1, d2 in rel:
if d1 not in rel_set:
rel_set[d1] = {}
if label not in rel_set[d1]:
rel_set[d1][label] = []
rel_set[d1][label].append(d2)
for d1 in rel_set:
label_list = sorted(rel_set[d1].keys(), reverse=True)
for hidx, high_label in enumerate(label_list[:-1]):
for low_label in label_list[hidx + 1:]:
for high_d2 in rel_set[d1][high_label]:
for low_d2 in rel_set[d1][low_label]:
pair_list.append((d1, high_d2, low_d2))
print('Pair Instance Count:', len(pair_list))
f = open("./data/train/train.txt", "w")
for batch in range(800):
X1 = np.zeros((batch_size * 2, data1_maxlen), dtype=np.int32)
X2 = np.zeros((batch_size * 2, data2_maxlen), dtype=np.int32)
X1[:] = _PAD_
X2[:] = _PAD_
for i in range(batch_size):
d1, d2p, d2n = random.choice(pair_list)
d1_len = min(data1_maxlen, len(query_data[d1]))
d2p_len = min(data2_maxlen, len(doc_data[d2p]))
d2n_len = min(data2_maxlen, len(doc_data[d2n]))
X1[i, :d1_len] = query_data[d1][:d1_len]
X2[i, :d2p_len] = doc_data[d2p][:d2p_len]
X1[i + batch_size, :d1_len] = query_data[d1][:d1_len]
X2[i + batch_size, :d2n_len] = doc_data[d2n][:d2n_len]
for i in range(batch_size * 2):
q = [str(x) for x in list(X1[i])]
d = [str(x) for x in list(X2[i])]
f.write(",".join(q) + "\t" + ",".join(d) + "\n")
f.close()
def make_test():
rel = read_relation(filename=os.path.join(Letor07Path,
'relation.test.fold1.txt'))
f = open("./data/test/test.txt", "w")
for label, d1, d2 in rel:
X1 = np.zeros(data1_maxlen, dtype=np.int32)
X2 = np.zeros(data2_maxlen, dtype=np.int32)
X1[:] = _PAD_
X2[:] = _PAD_
d1_len = min(data1_maxlen, len(query_data[d1]))
d2_len = min(data2_maxlen, len(doc_data[d2]))
X1[:d1_len] = query_data[d1][:d1_len]
X2[:d2_len] = doc_data[d2][:d2_len]
q = [str(x) for x in list(X1)]
d = [str(x) for x in list(X2)]
f.write(",".join(q) + "\t" + ",".join(d) + "\t" + str(label) + "\t" +
d1 + "\n")
f.close()
make_train()
make_test()
2 9639 GX099-60-3149248
1 9639 GX028-47-6554966
1 9639 GX031-84-2802741
1 9639 GX031-86-1702683
1 9639 GX031-89-11392170
1 9639 GX035-46-10142187
1 9639 GX039-07-1333080
1 9639 GX040-05-15096071
1 9639 GX045-35-10693225
1 9639 GX045-74-6226888
1 9639 GX046-31-8871083
1 9639 GX046-56-6274894
1 9639 GX050-09-14629105
1 9639 GX097-05-12714275
1 9639 GX101-06-7768196
1 9639 GX124-50-4934142
1 9639 GX259-01-13320140
1 9639 GX259-50-8109630
1 9639 GX259-72-16176934
1 9639 GX259-98-7821925
1 9639 GX260-27-13260880
1 9639 GX260-54-6363694
1 9639 GX260-78-6999656
1 9639 GX261-04-0843988
1 9639 GX261-23-4964814
0 9639 GX021-75-7026755
0 9639 GX021-80-16449591
0 9639 GX025-40-7135810
0 9639 GX031-89-9020252
0 9639 GX037-45-0533209
0 9639 GX038-17-11223353
0 9639 GX057-07-13335832
0 9639 GX081-50-12756687
0 9639 GX124-43-2364716
0 9639 GX129-60-0000000
0 9639 GX219-07-7475581
0 9639 GX233-90-7976935
0 9639 GX267-49-2983064
0 9639 GX267-74-2413254
0 9639 GX270-05-13614294
1 9329 GX234-05-0812081
0 9329 GX000-00-0000000
0 9329 GX008-50-3899336
0 9329 GX011-75-8470249
0 9329 GX020-42-13388867
0 9329 GX024-91-8520306
0 9329 GX026-88-6087429
0 9329 GX027-22-1703847
0 9329 GX034-11-2617393
0 9329 GX036-02-7994497
0 9329 GX046-08-13858054
0 9329 GX059-85-11403109
0 9329 GX099-37-0232298
0 9329 GX099-46-11473306
0 9329 GX108-04-9589788
0 9329 GX110-50-11723940
0 9329 GX124-11-4119164
0 9329 GX149-82-15204191
0 9329 GX165-95-6198495
0 9329 GX225-56-4184936
0 9329 GX229-57-4487470
0 9329 GX230-37-4125963
0 9329 GX231-40-14574318
0 9329 GX238-44-10302536
0 9329 GX239-85-8572461
0 9329 GX244-17-10154048
0 9329 GX245-16-4169590
0 9329 GX245-46-6341859
0 9329 GX246-91-8487173
0 9329 GX262-88-13259441
0 9329 GX263-41-4135561
0 9329 GX264-07-6385713
0 9329 GX264-38-12253757
0 9329 GX264-90-15990025
0 9329 GX265-89-6212449
0 9329 GX268-41-12034794
0 9329 GX268-83-5140660
0 9329 GX270-46-0293828
0 9329 GX270-64-11852140
0 9329 GX271-10-12458597
2 9326 GX272-03-6610348
1 9326 GX011-12-0595978
0 9326 GX000-00-0000000
0 9326 GX000-38-9492606
0 9326 GX000-84-4587136
0 9326 GX002-41-5566464
0 9326 GX002-51-2615036
0 9326 GX004-56-12238694
0 9326 GX004-72-2476906
0 9326 GX008-13-1835206
0 9326 GX008-64-7705528
0 9326 GX009-87-0976731
0 9326 GX012-24-7688369
0 9326 GX012-96-8727608
0 9326 GX023-87-16736657
0 9326 GX025-21-11820239
0 9326 GX025-22-15113698
0 9326 GX025-51-13959128
0 9326 GX025-57-11414648
0 9326 GX025-64-7587631
0 9326 GX027-62-4542881
0 9326 GX031-25-4759403
0 9326 GX036-10-7902858
0 9326 GX047-04-9457544
0 9326 GX047-06-4014803
0 9326 GX048-00-15113058
0 9326 GX048-02-12975919
0 9326 GX048-78-3273874
0 9326 GX235-35-0963257
0 9326 GX235-98-3789570
0 9326 GX236-51-15473637
0 9326 GX237-96-0892713
0 9326 GX239-35-7413891
0 9326 GX239-95-0176537
0 9326 GX251-34-10377030
0 9326 GX254-19-11374782
0 9326 GX260-63-10533444
0 9326 GX265-94-14886230
0 9326 GX269-78-1500497
0 9326 GX270-59-10270517
2 8946 GX046-79-6984659
2 8946 GX148-33-1869479
2 8946 GX252-36-12638222
1 8946 GX017-47-13290921
1 8946 GX030-69-3218092
1 8946 GX034-82-4550348
1 8946 GX044-01-9283107
1 8946 GX047-98-6660623
1 8946 GX057-96-12580825
1 8946 GX059-94-12068143
1 8946 GX060-13-13600036
1 8946 GX060-74-6594973
1 8946 GX093-08-1158999
0 8946 GX000-00-0000000
0 8946 GX000-42-15811803
0 8946 GX000-81-16418910
0 8946 GX008-38-10557859
0 8946 GX011-01-10891808
0 8946 GX013-71-5708874
0 8946 GX015-72-4458924
0 8946 GX023-91-9869060
0 8946 GX027-56-6376748
0 8946 GX037-11-10829529
0 8946 GX038-55-0681330
0 8946 GX043-86-4200105
0 8946 GX047-52-3712485
0 8946 GX053-77-4836617
0 8946 GX070-62-1070063
0 8946 GX105-53-13372327
0 8946 GX218-61-6263172
0 8946 GX223-72-13625320
0 8946 GX230-68-14727182
0 8946 GX235-34-7733230
0 8946 GX251-73-0159347
0 8946 GX254-47-1098586
0 8946 GX263-76-6934681
0 8946 GX263-84-8668756
0 8946 GX264-70-14223639
0 8946 GX269-12-5910753
0 8946 GX271-93-9895614
1 9747 GX006-77-1973537
1 9747 GX244-83-8716953
1 9747 GX269-92-7189826
0 9747 GX000-00-0000000
0 9747 GX001-51-8693413
0 9747 GX003-10-2820641
0 9747 GX003-74-0557776
0 9747 GX003-79-13695689
0 9747 GX009-57-0938999
0 9747 GX009-59-8595527
0 9747 GX009-80-10629348
0 9747 GX010-37-0206372
0 9747 GX013-46-2187318
0 9747 GX014-58-4004859
0 9747 GX015-79-5393654
0 9747 GX032-50-7316370
0 9747 GX049-33-2206612
0 9747 GX050-34-0439256
0 9747 GX062-76-0914936
0 9747 GX065-73-7392661
0 9747 GX148-27-15770966
0 9747 GX155-71-0504939
0 9747 GX229-75-14750078
0 9747 GX231-01-0640962
0 9747 GX236-45-15598812
0 9747 GX247-19-9516715
0 9747 GX247-34-4277646
0 9747 GX247-63-10766287
0 9747 GX248-23-15998266
0 9747 GX249-85-9742193
0 9747 GX250-31-7671617
0 9747 GX252-56-2141580
0 9747 GX253-15-3406713
0 9747 GX264-07-15838087
0 9747 GX264-43-6543997
0 9747 GX266-18-14688076
0 9747 GX267-50-2036010
0 9747 GX268-28-0548507
0 9747 GX269-49-14171555
0 9747 GX269-63-15607386
2 9740 GX005-94-14208849
2 9740 GX008-51-5639660
2 9740 GX012-37-2342061
2 9740 GX019-75-13916532
2 9740 GX074-76-16261807
2 9740 GX077-07-2951943
2 9740 GX229-28-11068981
2 9740 GX237-80-7497206
2 9740 GX257-53-10589749
2 9740 GX258-06-0611419
2 9740 GX268-55-9791226
1 9740 GX007-62-1126118
1 9740 GX015-78-0216468
1 9740 GX038-65-1678199
1 9740 GX041-25-14803324
1 9740 GX063-71-0401425
1 9740 GX077-08-15801730
1 9740 GX098-07-2885671
1 9740 GX135-28-6485892
1 9740 GX228-85-10518518
1 9740 GX231-93-11279468
1 9740 GX234-70-15061254
1 9740 GX236-31-11149347
1 9740 GX240-68-1184464
1 9740 GX248-03-7275316
1 9740 GX253-11-9846012
1 9740 GX255-05-10638500
1 9740 GX267-73-4450097
1 9740 GX269-19-0642640
0 9740 GX001-74-5132048
0 9740 GX001-88-2603815
0 9740 GX004-83-7935833
0 9740 GX007-01-16750210
0 9740 GX040-11-5249209
0 9740 GX042-38-2886005
0 9740 GX052-20-4359789
0 9740 GX067-74-3718011
0 9740 GX077-01-13481396
0 9740 GX242-92-8868913
0 9740 GX262-74-4596688
2 8835 GX010-99-5715419
2 8835 GX049-99-2518724
0 8835 GX000-00-0000000
0 8835 GX007-91-6779497
0 8835 GX008-14-0788708
0 8835 GX008-15-13942125
0 8835 GX011-58-14336551
0 8835 GX012-79-10684001
0 8835 GX013-00-10822427
0 8835 GX013-03-5962783
0 8835 GX015-54-0251701
0 8835 GX017-36-5859317
0 8835 GX017-60-0601078
0 8835 GX027-24-16202205
0 8835 GX030-11-15814183
0 8835 GX030-76-11969233
因为 它太大了无法显示 source diff 。你可以改为 查看blob
因为 它太大了无法显示 source diff 。你可以改为 查看blob
#!/bin/bash
echo "...........load data................."
wget --no-check-certificate 'https://paddlerec.bj.bcebos.com/match_pyramid/match_pyramid_data.tar.gz'
mv ./match_pyramid_data.tar.gz ./data
rm -rf ./data/relation.test.fold1.txt ./data/realtion.train.fold1.txt
tar -xvf ./data/match_pyramid_data.tar.gz
echo "...........data process..............."
python ./data/process.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
def eval_MAP(pred, gt):
map_value = 0.0
r = 0.0
c = list(zip(pred, gt))
random.shuffle(c)
c = sorted(c, key=lambda x: x[0], reverse=True)
for j, (p, g) in enumerate(c):
if g != 0:
r += 1
map_value += r / (j + 1.0)
if r == 0:
return 0.0
else:
return map_value / r
filename = './data/relation.test.fold1.txt'
gt = []
qid = []
f = open(filename, "r")
f.readline()
num = 0
for line in f.readlines():
num = num + 1
line = line.strip().split()
gt.append(int(line[0]))
qid.append(line[1])
f.close()
print(num)
filename = './result.txt'
pred = []
for line in open(filename):
line = line.strip().split(",")
line[1] = line[1].split(":")
line = line[1][1].strip(" ")
line = line.strip("[")
line = line.strip("]")
pred.append(float(line))
result_dict = {}
for i in range(len(qid)):
if qid[i] not in result_dict:
result_dict[qid[i]] = []
result_dict[qid[i]].append([gt[i], pred[i]])
print(len(result_dict))
map = 0
for qid in result_dict:
gt = np.array(result_dict[qid])[:, 0]
pred = np.array(result_dict[qid])[:, 1]
map += eval_MAP(pred, gt)
map = map / len(result_dict)
print("map=", map)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import random
import numpy as np
import paddle
import paddle.fluid as fluid
from paddlerec.core.utils import envs
from paddlerec.core.model import ModelBase
class Model(ModelBase):
def __init__(self, config):
ModelBase.__init__(self, config)
def _init_hyper_parameters(self):
self.emb_path = envs.get_global_env("hyper_parameters.emb_path")
self.sentence_left_size = envs.get_global_env(
"hyper_parameters.sentence_left_size")
self.sentence_right_size = envs.get_global_env(
"hyper_parameters.sentence_right_size")
self.vocab_size = envs.get_global_env("hyper_parameters.vocab_size")
self.emb_size = envs.get_global_env("hyper_parameters.emb_size")
self.kernel_num = envs.get_global_env("hyper_parameters.kernel_num")
self.hidden_size = envs.get_global_env("hyper_parameters.hidden_size")
self.hidden_act = envs.get_global_env("hyper_parameters.hidden_act")
self.out_size = envs.get_global_env("hyper_parameters.out_size")
self.channels = envs.get_global_env("hyper_parameters.channels")
self.conv_filter = envs.get_global_env("hyper_parameters.conv_filter")
self.conv_act = envs.get_global_env("hyper_parameters.conv_act")
self.pool_size = envs.get_global_env("hyper_parameters.pool_size")
self.pool_stride = envs.get_global_env("hyper_parameters.pool_stride")
self.pool_type = envs.get_global_env("hyper_parameters.pool_type")
self.pool_padding = envs.get_global_env(
"hyper_parameters.pool_padding")
def input_data(self, is_infer=False, **kwargs):
sentence_left = fluid.data(
name="sentence_left",
shape=[-1, self.sentence_left_size, 1],
dtype='int64',
lod_level=0)
sentence_right = fluid.data(
name="sentence_right",
shape=[-1, self.sentence_right_size, 1],
dtype='int64',
lod_level=0)
return [sentence_left, sentence_right]
def embedding_layer(self, input):
"""
embedding layer
"""
if os.path.isfile(self.emb_path):
embedding_array = np.load(self.emb_path)
emb = fluid.layers.embedding(
input=input,
size=[self.vocab_size, self.emb_size],
padding_idx=0,
param_attr=fluid.ParamAttr(
name="word_embedding",
initializer=fluid.initializer.NumpyArrayInitializer(
embedding_array)))
else:
emb = fluid.layers.embedding(
input=input,
size=[self.vocab_size, self.emb_size],
padding_idx=0,
param_attr=fluid.ParamAttr(
name="word_embedding",
initializer=fluid.initializer.Xavier()))
return emb
def conv_pool_layer(self, input):
"""
convolution and pool layer
"""
# data format NCHW
# same padding
conv = fluid.layers.conv2d(
input=input,
num_filters=self.kernel_num,
stride=1,
padding="SAME",
filter_size=self.conv_filter,
act=self.conv_act)
pool = fluid.layers.pool2d(
input=conv,
pool_size=self.pool_size,
pool_stride=self.pool_stride,
pool_type=self.pool_type,
pool_padding=self.pool_padding)
return pool
def net(self, inputs, is_infer=False):
left_emb = self.embedding_layer(inputs[0])
right_emb = self.embedding_layer(inputs[1])
cross = fluid.layers.matmul(left_emb, right_emb, transpose_y=True)
cross = fluid.layers.reshape(cross,
[-1, 1, cross.shape[1], cross.shape[2]])
conv_pool = self.conv_pool_layer(input=cross)
relu_hid = fluid.layers.fc(input=conv_pool,
size=self.hidden_size,
act=self.hidden_act)
prediction = fluid.layers.fc(
input=relu_hid,
size=self.out_size, )
if is_infer:
self._infer_results["prediction"] = prediction
return
pos = fluid.layers.slice(
prediction, axes=[0, 1], starts=[0, 0], ends=[64, 1])
neg = fluid.layers.slice(
prediction, axes=[0, 1], starts=[64, 0], ends=[128, 1])
loss_part1 = fluid.layers.elementwise_sub(
fluid.layers.fill_constant(
shape=[64, 1], value=1.0, dtype='float32'),
pos)
loss_part2 = fluid.layers.elementwise_add(loss_part1, neg)
loss_part3 = fluid.layers.elementwise_max(
fluid.layers.fill_constant(
shape=[64, 1], value=0.0, dtype='float32'),
loss_part2)
avg_cost = fluid.layers.mean(loss_part3)
self._cost = avg_cost
# match-pyramid文本匹配模型
## 介绍
在许多自然语言处理任务中,匹配两个文本是一个基本问题。一种有效的方法是从单词,短语和句子中提取有意义的匹配模式以产生匹配分数。受卷积神经网络在图像识别中的成功启发,神经元可以根据提取的基本视觉模式(例如定向的边角和边角)捕获许多复杂的模式,所以我们尝试将文本匹配建模为图像识别问题。本模型对齐原作者庞亮开源的tensorflow代码:https://github.com/pl8787/MatchPyramid-TensorFlow/blob/master/model/model_mp.py, 实现了下述论文中提出的Match-Pyramid模型:
```text
@inproceedings{Pang L , Lan Y , Guo J , et al. Text Matching as Image Recognition[J]. 2016.,
title={Text Matching as Image Recognition},
author={Liang Pang, Yanyan Lan, Jiafeng Guo, Jun Xu, Shengxian Wan, Xueqi Cheng},
year={2016}
}
```
## 数据准备
训练及测试数据集选用Letor07数据集和 embed_wiki-pdc_d50_norm 词向量初始化embedding层。
该数据集包括:
1.词典文件:我们将每个单词映射得到一个唯一的编号wid,并将此映射保存在单词词典文件中。例如:word_dict.txt
2.语料库文件:我们使用字符串标识符的值表示一个句子的编号。第二个数字表示句子的长度。例如:qid_query.txt和docid_doc.txt
3.关系文件:关系文件被用来存储两个句子之间的关系,如query 和document之间的关系。例如:relation.train.fold1.txt, relation.test.fold1.txt
4.嵌入层文件:我们将预训练的词向量存储在嵌入文件中。例如:embed_wiki-pdc_d50_norm
## 数据下载和预处理
本文提供了数据集的下载以及一键生成训练和测试数据的预处理脚本,您可以直接一键运行:bash data_process.sh
执行该脚本,会从国内源的服务器上下载Letor07数据集,删除掉data文件夹中原有的relation.test.fold1.txt和relation.train.fold1.txt,并将完整的数据集解压到data文件夹。随后运行 process.py 将全量训练数据放置于`./data/train`,全量测试数据放置于`./data/test`。并生成用于初始化embedding层的embedding.npy文件
执行该脚本的理想输出为:
```
bash data_process.sh
...........load data...............
--2020-07-13 13:24:50-- https://paddlerec.bj.bcebos.com/match_pyramid/match_pyramid_data.tar.gz
Resolving paddlerec.bj.bcebos.com... 10.70.0.165
Connecting to paddlerec.bj.bcebos.com|10.70.0.165|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 214449643 (205M) [application/x-gzip]
Saving to: “match_pyramid_data.tar.gz”
100%[==========================================================================================================>] 214,449,643 114M/s in 1.8s
2020-07-13 13:24:52 (114 MB/s) - “match_pyramid_data.tar.gz” saved [214449643/214449643]
data/
data/relation.test.fold1.txt
data/relation.test.fold2.txt
data/relation.test.fold3.txt
data/relation.test.fold4.txt
data/relation.test.fold5.txt
data/relation.train.fold1.txt
data/relation.train.fold2.txt
data/relation.train.fold3.txt
data/relation.train.fold4.txt
data/relation.train.fold5.txt
data/relation.txt
data/docid_doc.txt
data/qid_query.txt
data/word_dict.txt
data/embed_wiki-pdc_d50_norm
...........data process...............
[./data/word_dict.txt]
Word dict size: 193367
[./data/qid_query.txt]
Data size: 1692
[./data/docid_doc.txt]
Data size: 65323
[./data/embed_wiki-pdc_d50_norm]
Embedding size: 109282
('Generate numpy embed:', (193368, 50))
[./data/relation.train.fold1.txt]
Instance size: 47828
('Pair Instance Count:', 325439)
[./data/relation.test.fold1.txt]
Instance size: 13652
```
## 一键训练并测试评估
本文提供了一键执行训练,测试和评估的脚本,您可以直接一键运行:bash run.sh
执行该脚本后,会执行python -m paddlerec.run -m ./config.yaml 命令开始训练并测试模型,将测试的结果保存到result.txt文件,最后通过执行eval.py进行评估得到数据的map指标
执行该脚本的理想输出为:
```
..............test.................
13651
336
('map=', 0.420878322843591)
```
## 每个文件的作用
paddlerec可以:
通过config.yaml规定模型的参数
通过model.py规定模型的组网
使用train_reader.py读取训练集中的数据
使用test_reader.py读取测试集中的数据。
本文额外提供:
data_process.sh用来一键处理数据
run.sh用来一键启动训练,直接得出测试结果
eval.py通过保存的测试结果,计算map指标
如需详细了解paddlerec的使用方法请参考https://github.com/PaddlePaddle/PaddleRec/blob/master/README_CN.md 页面下方的教程。
#!/bin/bash
echo "................run................."
python -m paddlerec.run -m ./config.yaml >result1.txt
grep -A1 "prediction" ./result1.txt >./result.txt
rm -f result1.txt
python eval.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddlerec.core.reader import ReaderBase
class Reader(ReaderBase):
def init(self):
pass
def generate_sample(self, line):
"""
Read the data line by line and process it as a dictionary
"""
def reader():
"""
This function needs to be implemented by the user, based on data format
"""
features = line.strip('\n').split('\t')
doc1 = [int(word_id) for word_id in features[0].split(",")]
doc2 = [int(word_id) for word_id in features[1].split(",")]
features_name = ["doc1", "doc2"]
yield zip(features_name, [doc1] + [doc2])
return reader
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddlerec.core.reader import ReaderBase
class Reader(ReaderBase):
def init(self):
pass
def generate_sample(self, line):
"""
Read the data line by line and process it as a dictionary
"""
def reader():
"""
This function needs to be implemented by the user, based on data format
"""
features = line.strip('\n').split('\t')
doc1 = [int(word_id) for word_id in features[0].split(",")]
doc2 = [int(word_id) for word_id in features[1].split(",")]
features_name = ["doc1", "doc2"]
yield zip(features_name, [doc1] + [doc2])
return reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册