提交 63ea5c78 编写于 作者: Y yinhaofeng

merge

branch 'master' of https://github.com/PaddlePaddle/PaddleRec into multiview-simnet
......@@ -19,6 +19,7 @@ afs_local_mount_point="/root/paddlejob/workspace/env_run/afs/"
# 新k8s afs挂载帮助文档: http://wiki.baidu.com/pages/viewpage.action?pageId=906443193
PADDLE_PADDLEREC_ROLE=WORKER
PADDLEREC_CLUSTER_TYPE=K8S
use_python3=<$ USE_PYTHON3 $>
CPU_NUM=<$ CPU_NUM $>
GLOG_v=0
......
......@@ -17,6 +17,7 @@ output_path=<$ OUTPUT_PATH $>
thirdparty_path=<$ THIRDPARTY_PATH $>
PADDLE_PADDLEREC_ROLE=WORKER
PADDLEREC_CLUSTER_TYPE=MPI
use_python3=<$ USE_PYTHON3 $>
CPU_NUM=<$ CPU_NUM $>
GLOG_v=0
......
......@@ -107,6 +107,7 @@ class Trainer(object):
self.device = Device.GPU
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
self._place = fluid.CUDAPlace(gpu_id)
print("PaddleRec run on device GPU: {}".format(gpu_id))
self._exe = fluid.Executor(self._place)
elif device == "CPU":
self.device = Device.CPU
......@@ -146,6 +147,7 @@ class Trainer(object):
elif engine.upper() == "CLUSTER":
self.engine = EngineMode.CLUSTER
self.is_fleet = True
self.which_cluster_type()
else:
raise ValueError("Not Support Engine {}".format(engine))
self._context["is_fleet"] = self.is_fleet
......@@ -165,6 +167,14 @@ class Trainer(object):
self._context["is_pslib"] = (fleet_mode.upper() == "PSLIB")
self._context["fleet_mode"] = fleet_mode
def which_cluster_type(self):
cluster_type = os.getenv("PADDLEREC_CLUSTER_TYPE", "MPI")
print("PADDLEREC_CLUSTER_TYPE: {}".format(cluster_type))
if cluster_type and cluster_type.upper() == "K8S":
self._context["cluster_type"] = "K8S"
else:
self._context["cluster_type"] = "MPI"
def which_executor_mode(self):
executor_mode = envs.get_runtime_environ("train.trainer.executor_mode")
if executor_mode.upper() not in ["TRAIN", "INFER"]:
......
......@@ -21,7 +21,7 @@ from paddlerec.core.utils import envs
from paddlerec.core.utils import dataloader_instance
from paddlerec.core.reader import SlotReader
from paddlerec.core.trainer import EngineMode
from paddlerec.core.utils.util import split_files
from paddlerec.core.utils.util import split_files, check_filelist
__all__ = ["DatasetBase", "DataLoader", "QueueDataset"]
......@@ -119,14 +119,30 @@ class QueueDataset(DatasetBase):
dataset.set_pipe_command(pipe_cmd)
train_data_path = envs.get_global_env(name + "data_path")
file_list = [
os.path.join(train_data_path, x)
for x in os.listdir(train_data_path)
]
hidden_file_list, file_list = check_filelist(
hidden_file_list=[],
data_file_list=[],
train_data_path=train_data_path)
if (hidden_file_list is not None):
print(
"Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}".
format(hidden_file_list))
file_list.sort()
need_split_files = False
if context["engine"] == EngineMode.LOCAL_CLUSTER:
# for local cluster: split files for multi process
need_split_files = True
elif context["engine"] == EngineMode.CLUSTER and context[
"cluster_type"] == "K8S":
# for k8s mount afs, split files for every node
need_split_files = True
if need_split_files:
file_list = split_files(file_list, context["fleet"].worker_index(),
context["fleet"].worker_num())
print("File_list: {}".format(file_list))
dataset.set_filelist(file_list)
for model_dict in context["phases"]:
if model_dict["dataset_name"] == dataset_name:
......
......@@ -16,6 +16,7 @@ from __future__ import print_function
import os
import time
import warnings
import numpy as np
import paddle.fluid as fluid
......@@ -284,6 +285,7 @@ class RunnerBase(object):
return (epoch_id + 1) % epoch_interval == 0
def save_inference_model():
# get global env
name = "runner." + context["runner_name"] + "."
save_interval = int(
envs.get_global_env(name + "save_inference_interval", -1))
......@@ -296,16 +298,42 @@ class RunnerBase(object):
if feed_varnames is None or fetch_varnames is None or feed_varnames == "" or fetch_varnames == "" or \
len(feed_varnames) == 0 or len(fetch_varnames) == 0:
return
fetch_vars = [
fluid.default_main_program().global_block().vars[varname]
for varname in fetch_varnames
]
# check feed var exist
for var_name in feed_varnames:
if var_name not in fluid.default_main_program().global_block(
).vars:
raise ValueError(
"Feed variable: {} not in default_main_program, global block has follow vars: {}".
format(var_name,
fluid.default_main_program().global_block()
.vars.keys()))
# check fetch var exist
fetch_vars = []
for var_name in fetch_varnames:
if var_name not in fluid.default_main_program().global_block(
).vars:
raise ValueError(
"Fetch variable: {} not in default_main_program, global block has follow vars: {}".
format(var_name,
fluid.default_main_program().global_block()
.vars.keys()))
else:
fetch_vars.append(fluid.default_main_program()
.global_block().vars[var_name])
dirname = envs.get_global_env(name + "save_inference_path", None)
assert dirname is not None
dirname = os.path.join(dirname, str(epoch_id))
if is_fleet:
warnings.warn(
"Save inference model in cluster training is not recommended! Using save checkpoint instead.",
category=UserWarning,
stacklevel=2)
if context["fleet"].worker_index() == 0:
context["fleet"].save_inference_model(
context["exe"], dirname, feed_varnames, fetch_vars)
else:
......@@ -323,6 +351,7 @@ class RunnerBase(object):
return
dirname = os.path.join(dirname, str(epoch_id))
if is_fleet:
if context["fleet"].worker_index() == 0:
context["fleet"].save_persistables(context["exe"], dirname)
else:
fluid.io.save_persistables(context["exe"], dirname)
......
......@@ -19,7 +19,7 @@ from paddlerec.core.utils.envs import get_global_env
from paddlerec.core.utils.envs import get_runtime_environ
from paddlerec.core.reader import SlotReader
from paddlerec.core.trainer import EngineMode
from paddlerec.core.utils.util import split_files
from paddlerec.core.utils.util import split_files, check_filelist
def dataloader_by_name(readerclass,
......@@ -38,11 +38,27 @@ def dataloader_by_name(readerclass,
assert package_base is not None
data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
hidden_file_list, files = check_filelist(
hidden_file_list=[], data_file_list=[], train_data_path=data_path)
if (hidden_file_list is not None):
print(
"Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}".
format(hidden_file_list))
files.sort()
need_split_files = False
if context["engine"] == EngineMode.LOCAL_CLUSTER:
# for local cluster: split files for multi process
need_split_files = True
elif context["engine"] == EngineMode.CLUSTER and context[
"cluster_type"] == "K8S":
# for k8s mount mode, split files for every node
need_split_files = True
print("need_split_files: {}".format(need_split_files))
if need_split_files:
files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
print("file_list : {}".format(files))
reader = reader_class(yaml_file)
reader.init()
......@@ -80,11 +96,27 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file, context):
assert package_base is not None
data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
hidden_file_list, files = check_filelist(
hidden_file_list=[], data_file_list=[], train_data_path=data_path)
if (hidden_file_list is not None):
print(
"Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}".
format(hidden_file_list))
files.sort()
need_split_files = False
if context["engine"] == EngineMode.LOCAL_CLUSTER:
# for local cluster: split files for multi process
need_split_files = True
elif context["engine"] == EngineMode.CLUSTER and context[
"cluster_type"] == "K8S":
# for k8s mount mode, split files for every node
need_split_files = True
if need_split_files:
files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
print("file_list: {}".format(files))
sparse = get_global_env(name + "sparse_slots", "#")
if sparse == "":
......@@ -134,11 +166,27 @@ def slotdataloader(readerclass, train, yaml_file, context):
assert package_base is not None
data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
hidden_file_list, files = check_filelist(
hidden_file_list=[], data_file_list=[], train_data_path=data_path)
if (hidden_file_list is not None):
print(
"Warning:please make sure there are no hidden files in the dataset folder and check these hidden files:{}".
format(hidden_file_list))
files.sort()
need_split_files = False
if context["engine"] == EngineMode.LOCAL_CLUSTER:
# for local cluster: split files for multi process
need_split_files = True
elif context["engine"] == EngineMode.CLUSTER and context[
"cluster_type"] == "K8S":
# for k8s mount mode, split files for every node
need_split_files = True
if need_split_files:
files = split_files(files, context["fleet"].worker_index(),
context["fleet"].worker_num())
print("file_list: {}".format(files))
sparse = get_global_env("sparse_slots", "#", namespace)
if sparse == "":
......
......@@ -201,6 +201,28 @@ def split_files(files, trainer_id, trainers):
return trainer_files[trainer_id]
def check_filelist(hidden_file_list, data_file_list, train_data_path):
for root, dirs, files in os.walk(train_data_path):
if (files == None and dirs == None):
return None, None
else:
# use files and dirs
for file_name in files:
file_path = os.path.join(train_data_path, file_name)
if file_name[0] == '.':
hidden_file_list.append(file_path)
else:
data_file_list.append(file_path)
for dirs_name in dirs:
dirs_path = os.path.join(train_data_path, dirs_name)
if dirs_name[0] == '.':
hidden_file_list.append(dirs_path)
else:
#train_data_path = os.path.join(train_data_path, dirs_name)
check_filelist(hidden_file_list, data_file_list, dirs_path)
return hidden_file_list, data_file_list
class CostPrinter(object):
"""
For count cost time && print cost log
......
......@@ -16,9 +16,14 @@ workspace: "models/contentunderstanding/classification"
dataset:
- name: data1
batch_size: 5
batch_size: 10
type: DataLoader
data_path: "{workspace}/data/train_data"
data_path: "{workspace}/data/train"
data_converter: "{workspace}/reader.py"
- name: dataset_infer
batch_size: 2
type: DataLoader
data_path: "{workspace}/data/test"
data_converter: "{workspace}/reader.py"
hyper_parameters:
......@@ -26,23 +31,47 @@ hyper_parameters:
class: Adagrad
learning_rate: 0.001
is_sparse: False
dict_dim: 33257
max_len: 100
cnn_dim: 128
cnn_filter_size1: 1
cnn_filter_size2: 2
cnn_filter_size3: 3
emb_dim: 128
hid_dim: 96
class_dim: 2
mode: runner1
mode: [train_runner,infer_runner]
runner:
- name: runner1
- name: train_runner
class: train
epochs: 10
epochs: 16
device: cpu
save_checkpoint_interval: 2
save_inference_interval: 4
save_checkpoint_interval: 1
save_inference_interval: 1
save_checkpoint_path: "increment"
save_inference_path: "inference"
save_inference_feed_varnames: []
save_inference_fetch_varnames: []
init_model_path: ""
print_interval: 10
phases: phase_train
- name: infer_runner
class: infer
# device to run training or infer
device: cpu
print_interval: 1
init_model_path: "increment/14" # load model path
phases: phase_infer
phase:
- name: phase1
- name: phase_train
model: "{workspace}/model.py"
dataset_name: data1
thread_num: 1
- name: phase_infer
model: "{workspace}/model.py" # user-defined model
dataset_name: dataset_infer # select dataset by name
thread_num: 1
# encoding=utf-8
import os
import sys
def build_word_dict():
word_file = "word_dict.txt"
f = open(word_file, "r")
lines = f.readlines()
word_list_ids = range(1, len(lines) + 1)
word_dict = dict(zip([word.strip() for word in lines], word_list_ids))
f.close()
return word_dict
def build_token_data(word_dict, txt_file, token_file):
max_text_size = 100
f = open(txt_file, "r")
fout = open(token_file, "w")
lines = f.readlines()
i = 0
for line in lines:
line = line.strip("\n").split("\t")
text = line[0].strip("\n").split(" ")
tokens = []
label = line[1]
for word in text:
if word in word_dict:
tokens.append(str(word_dict[word]))
else:
tokens.append("0")
seg_len = len(tokens)
if seg_len < 5:
continue
if seg_len >= max_text_size:
tokens = tokens[:max_text_size]
seg_len = max_text_size
else:
tokens = tokens + ["0"] * (max_text_size - seg_len)
text_tokens = " ".join(tokens)
fout.write(text_tokens + " " + str(seg_len) + " " + label + "\n")
if (i + 1) % 100 == 0:
print(str(i + 1) + " lines OK")
i += 1
fout.close()
f.close()
word_dict = build_word_dict()
txt_file = "test.tsv"
token_file = "test.txt"
build_token_data(word_dict, txt_file, token_file)
txt_file = "dev.tsv"
token_file = "dev.txt"
build_token_data(word_dict, txt_file, token_file)
txt_file = "train.tsv"
token_file = "train.txt"
build_token_data(word_dict, txt_file, token_file)
5681 17044 4352 7574 16576 3574 32952 12211 18835 28961 15320 2019 21675 30604 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 14 1
9054 31881 4449 12211 12488 5975 3574 28592 2547 2547 14132 3574 24908 5975 24285 10010 3574 31872 20925 9886 12211 26530 3567 30818 19640 22506 28312 19887 12211 28212 8576 3574 28592 12306 14132 539 33049 9039 14160 113 3567 19675 5511 2111 623 12068 12211 3574 18416 12068 19680 12211 30781 21946 1525 9886 3574 28109 31201 3567 25710 30503 30781 12068 19887 12211 22052 3574 2050 5402 10217 31201 1525 9698 14160 19887 3574 26209 24908 539 33049 9039 32949 8890 29693 3566 3566 11053 30781 26853 3567 3567 0 0 0 0 0 0 0 0 92 0
19640 32771 31526 16576 13354 3574 5087 30781 7902 19037 12211 0 3574 4756 15048 11063 0 15019 16576 2019 29812 2276 22804 13275 2019 24599 12211 30294 6983 26606 1467 3574 18448 8052 16576 23091 32440 11034 16576 3574 1470 6983 1346 31382 13354 3574 11711 10074 28587 5030 19058 16576 2019 16497 6890 12223 30035 6983 1112 18448 30837 11280 24599 2019 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 64 0
7513 19838 3562 32737 15474 3562 1887 15474 0 0 18835 19813 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 12 1
30325 3574 30788 12211 25843 11533 30150 8937 11309 8690 12211 14166 2200 3574 15802 0 20424 14166 25336 113 16576 11533 24294 12211 26301 16576 3574 28592 16191 12211 8690 13743 0 517 12211 0 0 23958 3574 31019 19680 13841 15337 12211 23958 30781 28630 3574 8690 12700 11280 12211 23958 24908 20409 7481 8052 6094 4002 30245 3574 1526 9904 27032 31347 24006 12211 14166 0 9910 24908 12211 0 2019 25469 17293 27438 29774 13757 24908 22301 28505 25450 12211 14039 3574 28801 4621 4879 3574 623 9904 23958 14166 18417 4895 113 11114 2018 113 100 1
113 16576 17947 28955 12211 24253 3574 22068 30167 12211 14039 30818 28640 7801 2019 7985 30167 5402 6805 0 12211 27645 33067 30151 3574 11110 12211 10710 4549 22708 4308 24908 25975 12211 26957 0 2019 17942 25575 227 19641 1525 13129 113 15492 23224 3574 21163 15565 23273 29004 12452 13233 27573 12211 12046 2019 302 19367 16576 27914 0 0 113 12211 28035 0 13743 13330 24390 12466 1525 12537 3574 18131 2019 9315 25720 27416 2276 15038 18162 10024 28955 3574 10097 18162 26594 12211 21949 3574 30788 12133 26362 1779 27386 21017 14295 1525 454 100 1
33022 4169 19038 25096 3574 19185 113 25010 0 0 10511 17460 28972 6574 3574 1409 0 10010 3574 33022 129 16186 10511 17460 15182 3574 20235 10511 17460 11226 27150 13166 3562 18835 19038 5391 3574 22195 8052 28892 31948 10960 3574 13367 29338 15048 11030 22185 18621 28776 5205 2019 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 52 0
23439 330 0 0 29655 12211 3574 4211 3574 19650 19640 13757 3562 0 0 8990 330 0 0 18920 12211 31924 6688 31857 15364 3574 19641 30781 18416 28952 9209 12211 118 10710 16912 3562 0 0 27771 330 0 0 10126 30325 3574 15374 4348 0 6356 28420 24193 29526 12211 10523 21872 3571 24383 1580 3574 17536 1525 14745 21674 10710 4952 14871 3574 14590 20306 7695 0 32718 3562 0 0 13260 330 0 0 5847 30325 3574 25951 26995 21163 22787 15535 20889 3574 27914 5391 130 2276 15243 6356 0 16576 3562 0 0 100 1
24908 32568 24044 28952 16576 27914 28955 3574 14160 13543 16582 5536 2019 11711 3527 19675 12211 15474 3574 0 14160 31857 30927 2019 18416 9231 12486 12211 20374 3574 1111 30173 19058 3574 31857 31825 3574 30170 15501 21070 2019 31383 19640 5004 3574 31858 12211 6408 2733 8034 24870 12730 12211 16401 2019 18416 19640 9072 18416 12211 2313 12211 20374 3574 18416 2313 25575 19315 31383 20374 20161 24160 3574 11711 3527 3574 31383 20374 31857 28378 2019 1296 5402 23273 16576 2019 16497 28952 2019 9512 15038 5536 3574 11711 10486 15168 19641 21994 0 2019 100 1
0 7902 5402 29107 16576 15535 15535 15535 0 19634 21017 12211 26505 14160 15129 0 15535 15535 15535 26211 4002 9749 23360 16576 15535 15535 15535 26040 15535 15535 15535 15535 11698 32986 19641 0 22421 15535 15535 15535 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 40 0
28955 17755 3574 1735 18232 19262 12992 12230 3574 18416 30781 7388 19680 19643 16576 12211 3574 28952 9209 3574 16572 22360 2019 19680 19643 6414 12211 2011 27666 2012 3574 13757 32205 3574 14754 11280 12211 22186 7628 1827 17413 3574 19641 30781 31383 12211 4853 2019 33140 113 6047 6414 3310 31383 3574 4654 22360 6580 26147 12211 18696 2019 12306 6414 20539 3574 12680 22360 18624 8051 29384 1146 2019 18046 33188 16582 29384 12211 17311 13222 3574 18416 7453 28961 8014 3574 11711 18416 28961 17658 3574 29384 30781 19893 19643 15073 12211 32171 12211 2019 100 0
28955 12211 30964 14590 28961 4412 29183 29493 6393 17111 29183 11670 12211 19636 23233 28961 4412 29183 25469 1112 16603 14590 16720 28961 9749 32365 23958 12211 33245 1525 11271 29183 29607 4694 8052 12068 32247 26813 29183 12229 6856 3674 330 30326 972 32948 29183 18416 28961 20161 1120 19641 30054 28955 330 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 55 0
28587 26594 16393 14439 20100 8452 12211 11738 3574 20288 2276 2770 9051 29266 3574 27097 12211 0 14648 7902 5827 4308 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 22 1
19083 3561 20034 30173 8356 3574 18416 18016 6154 13757 30827 23410 4879 5213 3566 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 15 1
28587 14745 2018 1580 3574 19636 9052 14160 19683 16576 0 0 6007 5361 26370 5391 785 3574 0 17010 28587 27857 19048 20558 9051 3574 6007 0 0 22897 18323 1447 2019 0 0 32391 17536 24961 19048 9749 18448 3574 24283 6356 7648 26789 2019 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 47 0
24908 18920 1400 665 16167 12211 17293 3574 13518 28952 8393 23504 3574 31266 12211 30781 4477 2019 4654 18896 4289 13841 4822 3574 24908 27376 15243 18416 8052 20077 17493 17317 3574 14842 16949 3574 12081 28961 2276 0 14399 20158 14398 16335 12211 3699 7697 6318 69 2019 11924 8053 27376 12211 14039 3574 21210 23273 3574 1732 30818 17942 22561 3083 2019 17268 12700 28892 9108 16576 26203 19037 23872 3574 14988 31773 3574 33140 1725 24908 0 8053 8052 13841 3574 25944 0 2019 4032 5025 13841 19185 12211 14039 3574 665 0 12211 4822 6988 100 1
29728 31619 6149 5402 113 7317 11738 3574 31482 11924 16576 17657 6541 9761 3574 31224 5402 21141 3574 6356 16191 19640 14451 26154 7192 16076 3567 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 27 0
29302 11364 19059 13652 12211 3574 7898 30781 6356 7961 14954 21752 7340 2019 29302 11401 8328 3574 20384 20034 1460 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 21 0
4592 12211 31382 11030 3574 7961 6356 136 11714 31881 31478 3574 7957 11533 17413 3574 18835 14451 14550 11533 389 3574 14444 20444 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 24 1
18416 24908 0 5233 22185 12211 29183 18956 30781 9668 8904 15168 18416 16108 29183 18416 29123 4351 28845 11709 11731 30486 21200 3574 4351 32986 8052 13757 11711 16497 25138 18448 3006 30326 20837 6356 16060 11231 13757 18448 11731 29173 3576 18835 27924 11711 11533 11225 3574 17386 15934 7288 0 26216 12211 1542 3574 24908 12511 18416 16060 11231 32842 18448 11731 29173 3574 18956 9668 31387 755 32986 18416 28972 18855 30781 18448 3006 30326 20837 30781 8052 13757 15048 18448 11731 29173 12211 3574 19640 18584 18416 32986 25710 18416 2276 29173 12211 22052 24908 100 0
12 27 13 0 25 52 89 20 39 4 9 1
78 10 61 58 29 79 85 16 46 41 9 1
81 77 44 4 5 57 43 97 42 89 6 0
7 77 86 3 98 89 56 24 7 59 9 1
65 89 99 27 65 98 16 89 42 0 3 0
66 14 48 38 66 5 56 89 98 19 4 1
78 7 10 20 77 16 37 43 59 23 6 1
84 95 28 35 0 82 55 19 13 81 7 0
34 32 98 37 43 51 6 38 20 40 9 0
75 36 13 51 70 24 62 90 32 91 7 1
13 5 49 21 57 21 67 85 74 14 1 0
68 13 86 16 52 50 23 11 65 99 1 1
15 20 75 55 15 90 54 54 15 91 9 0
44 56 15 88 57 3 62 53 89 57 8 1
23 8 40 25 60 33 8 69 44 88 7 1
63 94 5 43 23 70 31 67 21 55 6 0
44 11 64 92 10 37 30 84 19 71 5 1
89 18 71 13 16 58 47 60 77 87 7 1
13 48 56 39 98 53 32 93 13 91 7 0
56 78 67 68 27 11 77 48 45 10 1 1
52 12 14 5 2 8 3 36 33 59 6 0
86 42 91 81 2 9 21 0 44 7 9 1
96 27 82 55 81 30 91 41 91 58 2 1
97 69 76 47 80 62 23 30 87 22 7 1
42 56 25 47 42 18 80 53 15 57 7 0
34 73 75 88 61 79 40 74 87 87 6 1
7 91 9 24 42 60 76 31 10 13 4 0
21 1 46 59 61 54 99 54 89 55 5 1
67 21 1 29 88 5 3 85 39 22 5 1
90 99 7 8 17 77 73 3 32 10 5 0
30 44 26 32 37 74 90 71 42 29 9 1
79 68 3 24 21 37 35 3 76 23 6 1
3 66 7 4 2 88 94 64 47 81 6 1
10 48 16 49 96 93 61 97 84 39 3 1
73 28 67 59 89 92 17 24 52 71 3 1
98 4 35 62 91 2 78 51 72 93 1 1
37 42 96 10 48 49 84 45 59 47 5 1
13 24 7 49 63 78 29 75 45 92 7 1
1 6 95 23 38 34 85 94 33 47 6 1
99 63 65 39 72 73 91 20 16 45 9 0
35 8 81 24 62 0 95 0 52 46 4 1
58 66 88 42 86 94 91 8 18 92 7 0
12 62 56 43 99 31 63 80 11 7 4 1
22 36 1 39 69 20 56 75 17 15 7 0
25 97 62 50 99 98 32 2 98 75 7 1
7 59 98 68 62 19 28 28 60 27 7 0
39 63 43 45 43 11 40 81 4 25 6 0
81 95 27 84 71 45 87 65 40 50 1 0
82 21 69 55 71 92 52 65 90 16 3 0
24 6 5 22 36 34 66 71 3 52 2 0
5 14 66 71 49 10 52 81 32 14 1 0
8 94 52 23 60 27 43 19 89 91 9 0
26 14 36 37 28 94 46 96 11 80 8 1
89 19 77 66 48 75 62 58 90 81 8 1
25 43 95 21 25 81 39 79 9 74 9 0
25 2 64 27 67 36 59 68 99 66 5 1
13 46 41 55 89 93 79 83 32 52 6 0
49 77 57 9 91 49 86 50 32 5 2 0
94 7 53 54 70 69 5 51 59 91 5 1
24 72 94 13 17 12 2 67 0 89 6 1
70 38 19 27 38 87 72 41 98 84 6 1
89 76 82 4 69 64 97 77 88 58 9 0
67 41 99 1 80 38 96 24 67 59 3 1
42 83 50 19 97 99 99 50 46 76 8 1
43 99 63 40 93 15 3 57 11 0 1 0
16 65 31 43 89 37 98 63 29 69 8 1
39 5 65 45 12 82 46 87 82 93 8 0
34 69 82 13 4 20 92 58 46 83 2 1
46 79 87 57 87 23 72 95 37 88 8 0
41 72 81 71 60 15 32 1 9 97 3 0
84 98 15 78 39 82 89 74 46 32 9 0
16 18 92 80 50 44 98 45 15 41 3 1
74 78 81 40 17 65 38 21 27 9 1 0
14 69 68 50 57 11 62 2 89 54 6 0
70 29 79 29 44 56 33 27 25 4 3 1
44 20 87 67 65 41 93 37 99 78 1 1
93 57 87 11 33 40 21 3 47 87 9 1
8 3 24 49 99 48 40 22 99 41 2 0
19 90 9 83 93 22 36 96 44 73 7 1
4 73 2 88 79 90 32 48 45 12 5 0
24 58 34 67 85 62 84 48 14 79 5 1
54 69 19 18 59 78 84 48 61 46 4 0
72 69 95 26 30 74 49 30 95 61 8 0
73 29 46 39 48 30 97 63 89 34 9 1
51 32 44 22 70 69 91 81 74 52 3 0
99 66 89 71 31 42 5 40 21 12 6 0
58 26 59 56 91 49 79 57 57 74 6 1
30 36 59 74 6 30 17 1 99 38 4 0
43 48 77 86 67 25 38 36 3 91 4 1
67 24 51 34 37 8 98 76 84 13 1 1
73 47 88 15 32 99 67 26 28 89 3 1
91 66 11 86 5 12 15 43 79 89 1 1
15 60 43 58 61 0 62 32 98 29 9 0
80 36 78 42 70 52 2 10 42 41 6 1
36 16 46 34 96 39 8 21 86 54 5 1
80 72 13 1 28 49 73 90 81 34 1 0
73 64 86 9 94 49 44 38 47 64 2 0
69 90 69 36 60 45 39 7 41 72 8 0
31 86 54 82 81 77 93 99 68 63 1 1
95 76 97 36 40 12 4 95 59 64 4 1
88 20 64 40 27 11 96 40 41 73 6 0
28 72 70 43 34 54 98 43 29 63 5 0
78 72 4 47 47 38 73 8 65 40 3 1
91 64 51 93 8 78 53 15 42 32 4 0
34 36 45 9 16 0 51 40 90 29 2 1
80 93 65 80 11 19 26 61 29 8 4 0
94 11 60 36 58 98 43 90 64 1 1 0
42 54 89 86 80 72 81 48 19 67 5 0
81 25 30 60 59 20 75 38 75 29 6 0
84 16 48 28 23 20 53 13 32 90 1 0
58 31 77 68 27 88 51 97 70 93 8 1
63 67 85 6 35 22 28 65 8 7 3 0
54 75 93 58 98 9 15 37 61 38 6 1
56 24 50 62 63 47 9 4 58 30 8 1
64 91 32 68 50 90 51 86 52 6 1 1
55 50 46 41 28 1 11 39 75 9 1 0
23 27 98 73 25 7 89 48 7 44 4 1
86 98 68 1 74 46 15 92 59 25 9 1
95 86 72 13 33 60 62 83 96 84 1 0
9 58 37 50 57 16 78 0 21 80 2 0
82 94 74 42 3 60 61 93 34 22 3 1
16 97 97 14 47 50 90 35 9 58 5 0
70 94 82 42 85 88 59 58 6 68 9 0
14 58 24 44 8 29 12 18 26 80 7 0
22 23 7 82 39 28 96 92 23 40 5 1
40 31 72 94 20 81 89 4 42 1 5 0
57 63 71 41 28 2 39 67 90 54 6 0
9 74 4 41 11 31 15 21 44 32 6 1
31 28 66 66 61 78 72 80 82 88 3 1
79 18 1 59 35 62 0 72 78 97 7 0
14 19 30 63 38 37 12 15 54 15 6 1
54 91 37 79 60 35 55 62 94 84 7 1
10 55 78 96 45 55 35 56 54 70 6 1
23 46 15 93 66 11 32 45 74 25 4 0
51 55 9 9 88 59 21 66 87 12 1 1
90 22 38 66 12 9 30 48 55 85 1 1
39 23 82 29 57 76 79 56 3 19 2 0
7 72 76 15 90 23 40 40 33 39 4 1
60 64 34 11 18 18 38 39 53 37 1 1
85 72 51 47 83 90 32 96 78 23 9 1
85 51 96 31 83 70 57 65 15 0 6 0
41 11 56 94 40 6 62 86 68 83 7 0
34 82 44 30 2 2 94 62 41 27 6 1
54 86 50 83 76 65 0 87 80 70 7 0
97 50 65 78 2 90 28 5 12 56 5 1
34 19 68 93 11 9 14 87 22 70 9 0
63 77 27 20 20 37 65 51 29 29 9 1
22 79 98 57 56 97 43 49 4 80 4 1
6 4 35 54 4 36 1 79 85 35 6 0
12 55 68 61 91 43 49 5 93 27 8 0
64 22 69 16 63 20 28 60 13 35 7 1
9 19 60 89 62 29 47 33 6 13 4 0
14 15 39 86 47 75 7 70 57 60 6 1
90 63 12 43 28 46 39 97 83 42 6 0
49 3 3 64 59 46 30 13 61 10 2 0
79 47 29 47 54 38 50 66 18 63 5 1
98 67 1 22 66 32 91 77 63 33 3 0
72 22 10 27 28 44 29 66 71 1 7 0
20 52 19 23 9 38 1 93 83 73 5 0
88 57 22 64 93 66 20 90 78 2 7 1
90 86 41 28 14 25 86 73 7 21 4 0
63 91 0 29 2 78 86 76 9 20 4 1
3 57 91 37 21 85 80 99 18 79 1 1
69 95 36 6 85 47 83 83 61 52 4 0
72 4 34 16 59 78 56 70 27 44 9 1
58 42 6 53 21 7 83 38 86 66 5 0
22 86 22 21 86 22 83 38 62 19 4 0
14 63 20 53 98 76 10 22 35 76 9 1
16 88 13 66 37 33 11 40 61 97 2 1
60 9 98 35 51 11 98 73 67 26 6 1
25 48 87 93 58 58 15 9 23 13 7 1
61 47 47 36 97 22 63 35 9 38 5 1
94 49 41 38 0 81 59 39 13 65 3 0
88 82 71 96 76 16 57 24 72 36 5 1
28 46 8 95 94 86 63 1 42 63 6 0
12 95 29 66 64 77 19 26 73 53 4 0
19 5 52 34 13 62 6 4 25 58 5 0
18 39 39 56 73 29 5 15 13 82 1 1
50 66 99 67 76 25 43 12 24 67 9 0
74 56 61 97 23 63 22 63 6 83 2 1
10 96 13 49 43 20 58 19 99 58 7 1
2 95 31 4 99 91 27 90 85 32 3 0
41 23 20 71 41 75 75 35 16 12 3 1
21 33 87 57 19 27 94 36 80 10 6 0
8 0 25 74 14 61 86 8 42 82 9 0
23 33 91 19 84 99 95 92 29 31 8 0
94 94 5 6 98 23 37 65 14 25 6 1
42 16 39 32 2 20 86 81 90 91 8 0
72 39 20 63 88 52 65 81 77 96 4 0
48 73 65 75 89 36 75 36 11 35 8 0
79 74 3 29 63 20 76 46 8 82 5 0
7 46 38 77 79 92 71 98 30 35 6 0
44 69 93 31 22 68 91 70 32 86 5 0
45 38 77 87 64 44 69 19 28 82 9 0
93 63 92 84 22 44 51 94 4 99 9 0
77 10 49 29 59 55 44 7 95 39 2 0
10 85 99 9 91 29 64 14 50 24 6 1
74 4 21 12 77 36 71 51 50 31 9 1
66 76 28 18 23 49 33 31 6 44 1 1
92 50 90 64 95 58 93 4 78 88 6 1
69 79 76 47 46 26 30 40 33 58 8 1
97 12 87 82 6 18 57 49 49 58 1 1
70 79 55 86 29 88 55 39 17 74 5 1
65 51 45 62 54 17 59 12 29 79 5 0
5 63 82 51 54 97 54 36 57 46 3 0
74 77 52 10 12 9 34 95 2 0 5 0
50 20 22 89 50 70 55 98 80 50 1 0
61 80 7 3 78 36 44 37 90 18 9 0
81 13 55 57 88 81 66 55 18 34 2 1
52 30 54 70 28 56 48 82 67 20 8 1
0 41 15 63 27 90 12 16 56 79 3 0
69 89 54 1 93 10 15 2 25 59 8 0
74 99 17 93 96 82 38 77 98 85 4 0
8 59 17 92 60 21 59 76 55 73 2 1
53 56 79 19 29 94 86 96 62 39 3 1
23 44 25 63 41 94 65 10 8 40 9 1
7 18 80 43 20 70 14 59 72 17 9 0
84 97 79 14 37 64 23 68 8 24 2 0
63 94 98 77 8 62 10 77 63 56 4 0
8 63 74 34 49 22 52 54 44 93 3 0
94 48 92 58 82 48 53 34 96 25 2 0
33 15 3 95 48 93 9 69 44 77 7 1
69 72 80 77 64 24 52 21 36 49 2 0
59 34 54 66 60 19 76 79 16 70 5 1
8 83 9 91 67 79 31 20 31 88 2 0
64 95 46 95 78 63 4 60 66 63 7 1
10 39 78 45 36 4 89 94 68 75 7 0
81 52 70 11 48 15 40 63 29 14 8 1
94 49 30 14 53 12 53 42 77 82 8 1
40 88 46 20 54 84 76 15 2 73 2 1
71 50 79 54 17 58 30 16 17 99 1 1
74 79 74 61 61 36 28 39 89 36 6 0
53 45 45 23 51 32 93 26 10 8 3 0
1 97 6 67 88 20 41 63 49 6 8 0
3 64 41 19 41 80 75 71 69 90 8 0
31 90 38 93 52 0 38 86 41 68 9 1
50 94 53 9 73 59 94 7 24 57 3 0
87 11 4 62 96 7 0 59 46 11 6 1
77 67 56 88 45 62 10 51 86 27 6 1
62 62 59 99 83 84 79 97 56 37 5 0
19 55 0 37 44 44 2 7 54 50 5 1
23 60 11 83 6 48 20 77 54 31 6 0
27 53 52 30 3 70 57 38 47 96 5 0
75 14 5 83 72 46 47 64 14 12 7 0
29 95 36 63 59 49 38 44 13 15 2 1
38 3 70 89 2 94 89 74 33 6 8 1
28 56 49 43 83 34 7 63 36 13 7 0
25 90 23 85 50 65 36 10 64 38 5 0
35 94 48 38 99 71 42 39 61 75 8 1
28 73 34 22 51 8 52 98 74 19 8 1
12 40 65 12 7 96 73 65 12 90 5 0
42 42 48 16 80 14 48 29 29 45 5 0
58 20 4 0 69 99 15 4 16 4 1 1
93 30 90 5 23 63 25 30 99 32 7 1
91 23 20 26 84 78 58 76 58 90 5 1
33 2 36 59 55 9 79 34 92 57 9 0
80 63 84 73 22 40 70 94 59 34 5 0
49 95 50 32 90 22 18 66 46 32 2 0
47 72 3 94 33 78 87 43 11 67 5 0
76 44 86 81 95 48 79 46 11 65 8 1
59 51 97 75 17 5 40 59 32 62 6 0
41 13 58 7 54 84 8 84 27 55 1 0
24 80 44 26 86 99 68 80 81 22 9 0
12 45 16 44 66 76 33 53 3 20 9 0
22 3 79 6 32 38 75 66 15 25 9 1
51 48 26 53 33 26 18 74 9 39 5 1
35 67 89 91 29 81 23 52 19 11 6 0
64 50 43 1 43 49 19 20 84 19 8 0
34 4 9 77 24 61 55 82 42 76 9 0
37 84 94 33 67 60 3 95 78 8 9 0
82 10 54 12 47 23 78 97 6 51 5 0
70 40 38 47 5 38 83 70 37 90 2 0
42 21 62 27 43 47 82 80 88 49 4 0
68 68 67 12 38 13 32 30 93 27 3 1
5 44 98 28 5 81 20 56 10 34 9 1
40 46 11 33 73 62 68 70 66 85 4 0
9 46 11 84 6 31 18 89 66 32 1 1
6 78 44 98 77 29 69 39 62 78 1 0
47 90 18 0 3 8 12 20 51 75 4 1
21 29 74 19 12 29 41 22 63 47 8 1
22 59 64 62 18 89 19 92 87 8 8 0
6 21 24 58 14 53 18 93 62 15 8 0
20 33 88 25 37 52 1 72 74 11 2 0
90 49 28 53 28 80 22 81 0 46 9 0
87 31 51 27 15 31 68 93 5 4 7 1
21 72 60 2 24 79 22 24 77 61 9 0
20 4 6 40 28 14 16 78 58 99 7 1
80 35 98 20 91 35 47 29 3 19 2 1
57 21 24 61 60 39 83 34 53 2 2 0
74 86 78 78 18 44 20 94 85 71 4 1
27 48 44 92 10 18 74 54 25 85 2 0
74 77 28 75 74 91 69 36 95 68 7 0
32 84 17 18 55 79 59 57 21 69 2 1
69 77 40 98 83 40 4 66 39 83 1 1
63 24 32 39 75 92 81 49 2 51 5 1
35 40 84 71 3 16 82 91 44 52 8 0
21 78 66 4 57 27 21 89 4 34 7 1
94 18 57 49 88 26 29 76 56 67 6 0
14 91 71 30 5 36 28 74 16 73 3 1
93 36 43 46 77 44 59 19 56 84 3 0
11 16 2 67 11 96 20 91 20 59 2 1
72 79 26 99 90 71 56 46 35 99 3 0
29 87 20 40 13 14 14 40 61 27 6 0
41 64 28 51 56 52 87 67 37 91 6 1
33 14 5 30 99 54 27 80 54 55 4 1
60 44 73 91 71 53 54 95 59 81 6 0
69 33 11 83 4 53 34 39 43 84 1 0
73 31 19 4 50 20 66 73 94 88 4 0
30 49 41 76 5 21 88 69 76 3 2 0
18 50 27 76 67 38 87 16 52 87 5 1
33 36 80 8 43 82 89 76 37 3 5 0
98 21 61 24 58 13 9 85 56 74 1 1
84 27 50 96 9 56 30 31 85 65 1 1
65 74 40 2 8 40 18 57 30 38 1 1
76 44 64 6 10 32 84 70 74 24 1 1
14 29 59 34 27 8 0 37 27 68 3 0
6 47 5 77 15 41 93 49 59 83 4 1
39 88 43 89 32 98 82 0 5 12 9 0
78 79 30 26 58 6 9 58 37 65 8 1
25 28 66 41 70 87 76 62 29 39 7 1
......@@ -20,28 +20,32 @@ from paddlerec.core.model import ModelBase
class Model(ModelBase):
def __init__(self, config):
ModelBase.__init__(self, config)
self.dict_dim = 100
self.max_len = 10
self.cnn_dim = 32
self.cnn_filter_size = 128
self.emb_dim = 8
self.hid_dim = 128
self.class_dim = 2
self.is_sparse = envs.get_global_env("hyper_parameters.is_sparse",
False)
self.dict_dim = envs.get_global_env("hyper_parameters.dict_dim")
self.max_len = envs.get_global_env("hyper_parameters.max_len")
self.cnn_dim = envs.get_global_env("hyper_parameters.cnn_dim")
self.cnn_filter_size1 = envs.get_global_env(
"hyper_parameters.cnn_filter_size1")
self.cnn_filter_size2 = envs.get_global_env(
"hyper_parameters.cnn_filter_size2")
self.cnn_filter_size3 = envs.get_global_env(
"hyper_parameters.cnn_filter_size3")
self.emb_dim = envs.get_global_env("hyper_parameters.emb_dim")
self.hid_dim = envs.get_global_env("hyper_parameters.hid_dim")
self.class_dim = envs.get_global_env("hyper_parameters.class_dim")
self.is_sparse = envs.get_global_env("hyper_parameters.is_sparse")
def input_data(self, is_infer=False, **kwargs):
data = fluid.data(
name="input", shape=[None, self.max_len], dtype='int64')
label = fluid.data(name="label", shape=[None, 1], dtype='int64')
seq_len = fluid.data(name="seq_len", shape=[None], dtype='int64')
return [data, label, seq_len]
label = fluid.data(name="label", shape=[None, 1], dtype='int64')
return [data, seq_len, label]
def net(self, input, is_infer=False):
""" network definition """
data = input[0]
label = input[1]
seq_len = input[2]
seq_len = input[1]
label = input[2]
# embedding layer
emb = fluid.embedding(
......@@ -50,15 +54,31 @@ class Model(ModelBase):
is_sparse=self.is_sparse)
emb = fluid.layers.sequence_unpad(emb, length=seq_len)
# convolution layer
conv = fluid.nets.sequence_conv_pool(
conv1 = fluid.nets.sequence_conv_pool(
input=emb,
num_filters=self.cnn_dim,
filter_size=self.cnn_filter_size1,
act="tanh",
pool_type="max")
conv2 = fluid.nets.sequence_conv_pool(
input=emb,
num_filters=self.cnn_dim,
filter_size=self.cnn_filter_size,
filter_size=self.cnn_filter_size2,
act="tanh",
pool_type="max")
conv3 = fluid.nets.sequence_conv_pool(
input=emb,
num_filters=self.cnn_dim,
filter_size=self.cnn_filter_size3,
act="tanh",
pool_type="max")
convs_out = fluid.layers.concat(input=[conv1, conv2, conv3], axis=1)
# full connect layer
fc_1 = fluid.layers.fc(input=[conv], size=self.hid_dim)
fc_1 = fluid.layers.fc(input=convs_out, size=self.hid_dim, act="tanh")
# softmax layer
prediction = fluid.layers.fc(input=[fc_1],
size=self.class_dim,
......@@ -70,5 +90,7 @@ class Model(ModelBase):
self._cost = avg_cost
if is_infer:
self._infer_results["acc"] = acc
self._infer_results["loss"] = avg_cost
else:
self._metrics["acc"] = acc
self._metrics["loss"] = avg_cost
......@@ -23,9 +23,10 @@ class Reader(ReaderBase):
def _process_line(self, l):
l = l.strip().split()
data = l[0:10]
seq_len = l[10:11]
label = l[11:]
data = l[0:100]
seq_len = l[100:101]
label = l[101:]
return data, label, seq_len
def generate_sample(self, line):
......@@ -37,6 +38,6 @@ class Reader(ReaderBase):
data = [int(i) for i in data]
label = [int(i) for i in label]
seq_len = [int(i) for i in seq_len]
yield [('data', data), ('label', label), ('seq_len', seq_len)]
yield [('data', data), ('seq_len', seq_len), ('label', label)]
return data_iter
# classification文本分类模型
以下是本例的简要目录结构及说明:
```
├── data #样例数据
├── train
├── train.txt #训练数据样例
├── test
├── test.txt #测试数据样例
├── preprocess.py #数据处理程序
├── __init__.py
├── README.md #文档
├── model.py #模型文件
├── config.yaml #配置文件
├── reader.py #读取程序
```
注:在阅读该示例前,建议您先了解以下内容:
[paddlerec入门教程](https://github.com/PaddlePaddle/PaddleRec/blob/master/README.md)
## 内容
- [模型简介](#模型简介)
- [数据准备](#数据准备)
- [运行环境](#运行环境)
- [快速开始](#快速开始)
- [效果复现](#效果复现)
- [进阶使用](#进阶使用)
- [FAQ](#FAQ)
## 模型简介
TextCNN网络是2014年提出的用来做文本分类的卷积神经网络,由于其结构简单、效果好,在文本分类、推荐等NLP领域应用广泛。对于文本分类问题,常见的方法无非就是抽取文本的特征。然后再基于抽取的特征训练一个分类器。 然而研究证明,TextCnn在文本分类问题上有着更加卓越的表现。从直观上理解,TextCNN通过一维卷积来获取句子中n-gram的特征表示。TextCNN对文本浅层特征的抽取能力很强,在短文本领域专注于意图分类时效果很好,应用广泛,且速度较快。
Yoon Kim在论文[EMNLP 2014][Convolutional neural networks for sentence classication](https://www.aclweb.org/anthology/D14-1181.pdf)提出了TextCNN并给出基本的结构。将卷积神经网络CNN应用到文本分类任务,利用多个不同size的kernel来提取句子中的关键信息(类似于多窗口大小的ngram),从而能够更好地捕捉局部相关性。模型的主体结构如图所示:
<p align="center">
<img align="center" src="../../../doc/imgs/cnn-ckim2014.png">
<p>
## 数据准备
情感倾向分析(Sentiment Classification,简称Senta)针对带有主观描述的中文文本,可自动判断该文本的情感极性类别并给出相应的置信度。情感类型分为积极、消极。情感倾向分析能够帮助企业理解用户消费习惯、分析热点话题和危机舆情监控,为企业提供有利的决策支持。
情感是人类的一种高级智能行为,为了识别文本的情感倾向,需要深入的语义建模。另外,不同领域(如餐饮、体育)在情感的表达各不相同,因而需要有大规模覆盖各个领域的数据进行模型训练。为此,我们通过基于深度学习的语义模型和大规模数据挖掘解决上述两个问题。效果上,我们基于开源情感倾向分类数据集ChnSentiCorp进行评测,模型在测试集上的准确率如表所示:
| 模型 | dev | test |
| :------| :------ | :------
| TextCNN | 90.75% | 92.19% |
您可以直接执行以下命令下载我们分词完毕后的数据集,文件解压之后,senta_data目录下会存在训练数据(train.tsv)、开发集数据(dev.tsv)、测试集数据(test.tsv)以及对应的词典(word_dict.txt):
```
wget https://baidu-nlp.bj.bcebos.com/sentiment_classification-dataset-1.0.0.tar.gz
tar -zxvf sentiment_classification-dataset-1.0.0.tar.gz
```
数据格式为一句中文的评价语句,和一个代表情感信息的标签。两者之间用/t分隔,中文的评价语句已经分词,词之间用空格分隔。
```
15.4寸 笔记本 的 键盘 确实 爽 , 基本 跟 台式机 差不多 了 , 蛮 喜欢 数字 小 键盘 , 输 数字 特 方便 , 样子 也 很 美观 , 做工 也 相当 不错 1
跟 心灵 鸡汤 没 什么 本质 区别 嘛 , 至少 我 不 喜欢 这样 读 经典 , 把 经典 都 解读 成 这样 有点 去 中国 化 的 味道 了 0
```
## 运行环境
PaddlePaddle>=1.7.2
python 2.7/3.5/3.6/3.7
PaddleRec >=0.1
os : windows/linux/macos
## 快速开始
本文提供了样例数据可以供您快速体验,在paddlerec目录下直接执行下面的命令即可启动训练:
```
python -m paddlerec.run -m models/contentunderstanding/classification/config.yaml
```
## 效果复现
为了方便使用者能够快速的跑通每一个模型,我们在每个模型下都提供了样例数据。如果需要复现readme中的效果,请按如下步骤依次操作即可。
1. 确认您当前所在目录为PaddleRec/models/contentunderstanding/classification
2. 下载并解压数据集,命令如下:
```
wget https://baidu-nlp.bj.bcebos.com/sentiment_classification-dataset-1.0.0.tar.gz
tar -zxvf sentiment_classification-dataset-1.0.0.tar.gz
```
3. 本文提供了快速将数据集中的汉字数据处理为可训练格式数据的脚本,您在解压数据集后,将preprocess.py复制到senta_data文件中并执行,即可将数据集中提供的dev.tsv,test.tsv,train.tsv转化为可直接训练的dev.txt,test.txt,train.txt.
```
cp ./data/preprocess.py ./senta_data/
cd senta_data/
python preprocess.py
```
4. 创建存放训练集和测试集的目录,将数据放入目录中。
```
mkdir train
mv train.txt train
mkdir test
mv dev.txt test
cd ..
```
5. 打开文件config.yaml,更改其中的参数
将workspace改为您当前的绝对路径。(可用pwd命令获取绝对路径)
将data1下的batch_size值从10改为128
将data1下的data_path改为:{workspace}/senta_data/train
将dataset_infer下的batch_size值从2改为256
将dataset_infer下的data_path改为:{workspace}/senta_data/test
6. 执行命令,开始训练:
```
python -m paddlerec.run -m ./config.yaml
```
7. 运行结果:
```
PaddleRec: Runner infer_runner Begin
Executor Mode: infer
processor_register begin
Running SingleInstance.
Running SingleNetwork.
Running SingleInferStartup.
Running SingleInferRunner.
load persistables from increment/14
batch: 1, acc: [0.91796875], loss: [0.2287855]
batch: 2, acc: [0.91796875], loss: [0.22827303]
batch: 3, acc: [0.90234375], loss: [0.27907994]
```
## 进阶使用
## FAQ
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册