Unverified commit 6e509429, authored by 王肖, committed by GitHub

add similarity_net dygraph (#4289)

* Update README.md (#4267)

* test=develop (#4269)

* 3d use new api (#4275)

* PointNet++ and PointRCNN use new API

* Update Readme of Dygraph BERT (#4277)

Fix some typos.

* Update run_classifier_multi_gpu.sh (#4279)

remove the CUDA_VISIBLE_DEVICES

* Update README.md (#4280)

* add similarity_net dygraph
Co-authored-by: pkpk <xiyzhouang@gmail.com>
Co-authored-by: Kaipeng Deng <dengkaipeng@baidu.com>
Parent c4f3ebc3
#!/usr/bin/env bash
export FLAGS_enable_parallel_graph=1
export FLAGS_sync_nccl_allreduce=1
export FLAGS_fraction_of_gpu_memory_to_use=0.95
TASK_NAME='simnet'
TRAIN_DATA_PATH=./data/train_pointwise_data
VALID_DATA_PATH=./data/test_pointwise_data
TEST_DATA_PATH=./data/test_pointwise_data
INFER_DATA_PATH=./data/infer_data
VOCAB_PATH=./data/term2id.dict
CKPT_PATH=./model_files
TEST_RESULT_PATH=./test_result
INFER_RESULT_PATH=./infer_result
TASK_MODE='pointwise'
CONFIG_PATH=./config/bow_pointwise.json
INIT_CHECKPOINT=./model_files/simnet_bow_pointwise_pretrained_model/
# run_train
train() {
python run_classifier.py \
--task_name ${TASK_NAME} \
--use_cuda True \
--do_train True \
--do_valid True \
--do_test True \
--do_infer False \
--batch_size 128 \
--train_data_dir ${TRAIN_DATA_PATH} \
--valid_data_dir ${VALID_DATA_PATH} \
--test_data_dir ${TEST_DATA_PATH} \
--infer_data_dir ${INFER_DATA_PATH} \
--output_dir ${CKPT_PATH} \
--config_path ${CONFIG_PATH} \
--vocab_path ${VOCAB_PATH} \
--epoch 3 \
--save_steps 1000 \
--validation_steps 100 \
--compute_accuracy False \
--lamda 0.958 \
--task_mode ${TASK_MODE} \
--enable_ce
}
export CUDA_VISIBLE_DEVICES=0
train | python _ce.py
sleep 20
export CUDA_VISIBLE_DEVICES=0,1,2,3
train | python _ce.py
# Short Text Semantic Matching
## Introduction
### Task Description
Short Text Semantic Matching (SimilarityNet, SimNet) is a framework for computing the similarity of short texts: given two pieces of text from the user, it produces a similarity score. SimNet is widely used across Baidu products. Its core network structures include BOW, CNN, RNN, and MMDNN, and it provides a training and prediction framework for semantic similarity computation that suits scenarios such as information retrieval, news recommendation, and intelligent customer service, helping enterprises solve semantic matching problems. You can try it online at [AI Open Platform - Short Text Similarity](https://ai.baidu.com/tech/nlp_basic/simnet).
We also recommend the [IPython Notebook demo](https://aistudio.baidu.com/aistudio/projectDetail/124373).
### Results
Based on Baidu's massive search data, we trained a SimNet-BOW-Pairwise semantic matching model. In several real FAQ scenarios it improves AUC by more than 5% over surface-form (literal) similarity methods. We evaluated it on Baidu's in-house test sets (covering chitchat, customer service, and other data) and on the semantic matching dataset LCQMC; the results are shown in the table below. LCQMC is evaluated with accuracy, while the pairwise model outputs a similarity score, so we use 0.958 as the classification threshold; our model reaches an accuracy of 0.7532, compared with 0.737 for the CBOW baseline of comparable network complexity (a small thresholding sketch follows the table).
| Model | Baidu Zhidao | ECOM | QQSIM | UNICOM |
|:-----------:|:-------------:|:-------------:|:-------------:|:-------------:|
| | AUC | AUC | AUC | Positive/negative order ratio |
|BOW_Pairwise|0.6767|0.7329|0.7650|1.5630|
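The LCQMC accuracy above comes from thresholding the pairwise similarity score at 0.958. A minimal sketch of that step, using made-up scores and labels rather than real model output:
```python
# Hypothetical example: convert pairwise similarity scores into 0/1 predictions
# with the 0.958 threshold, then measure accuracy against gold labels.
def classify(scores, threshold=0.958):
    return [1 if score >= threshold else 0 for score in scores]

scores = [0.97, 0.42, 0.96, 0.80]   # made-up cosine similarities
labels = [1, 0, 1, 1]               # made-up gold labels
preds = classify(scores)
accuracy = float(sum(p == l for p, l in zip(preds, labels))) / len(labels)
print(preds, accuracy)              # [1, 0, 1, 0] 0.75
```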
#### Test Set Description
| Dataset | Source | Domain |
|:-----------:|:-------------:|:-------------:|
| Baidu Zhidao | Baidu Zhidao questions | General |
| ECOM | Commercial queries | Finance |
| QQSIM | Chitchat dialogues | General |
| UNICOM | China Unicom customer service | Customer service |
## Quick Start
#### Dependencies
This project requires PaddlePaddle Fluid 1.6; see the [installation guide](http://www.paddlepaddle.org/#quick-start).
Python 2.7 is required.
#### Getting the Code
Clone the model repository to your local machine:
```shell
git clone https://github.com/PaddlePaddle/models.git
cd models/PaddleNLP/similarity_net
```
#### Data Preparation
Download the preprocessed data. After the command finishes, the data directory will contain training, validation, and test data examples, together with the corresponding term-to-id dictionary (term2id.dict).
```shell
sh download_data.sh
```
or
```
python download.py dataset
```
#### Pretrained Model
We release a ```pairwise``` model (a BOW network) pretrained on large-scale data. Run the following command to download it; the model is saved under ```./model_files/simnet_bow_pairwise_pretrained_model/```.
```shell
sh download_pretrained_model.sh
```
or
```
python download.py model
```
#### Evaluation
We release our in-house test sets, covering the Baidu Zhidao, ECOM, QQSIM, and UNICOM datasets. With the pretrained model above, enter the evaluate directory and run the following commands in turn to get the evaluation results on each test set.
```shell
sh evaluate_ecom.sh
sh evaluate_qqsim.sh
sh evaluate_zhidao.sh
sh evaluate_unicom.sh
```
You can also point TEST_DATA_PATH in ./run.sh at your own test set and evaluate it with the following command.
```shell
sh run.sh eval
```
#### Inference
With the pretrained model above, run the following command to run inference and save the results locally.
```shell
sh run.sh infer
```
#### Training and Validation
You can build a training set and a development set from the sample data, then run the following command to train the model and validate it on the development set.
```shell
sh run.sh train
```
You can also set INIT_CHECKPOINT inside the train() function of ./run.sh to warm-start training from a previously trained model.
## Advanced Usage
### Task Definition and Modeling
Traditional text matching techniques, such as the vector space model (VSM) and BM25 in information retrieval, mainly address similarity at the lexical level, and in practice their effectiveness is limited by polysemy and sentence structure. SimNet keeps the implicit continuous vector representation of semantics, but models semantic matching end to end in a deep learning framework, unifying both ```point-wise``` and ```pair-wise``` supervised learning in a single framework (a small sketch contrasting the two losses follows). In real applications, massive user click logs can be converted into large-scale weakly labeled data; the first deployment on web search already showed great power and brought a clear improvement in relevance.
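A minimal numpy sketch of the two supervision modes, assuming cosine similarities and the 0.1 margin used by the hinge-loss configs later in this commit; this is an illustration, not the project's actual loss code:
```python
import numpy as np

def pointwise_loss(logits, label):
    # Binary classification: softmax cross-entropy over (not similar, similar).
    probs = np.exp(logits) / np.exp(logits).sum()
    return -np.log(probs[label])

def pairwise_hinge_loss(pos_sim, neg_sim, margin=0.1):
    # Ranking: the positive pair must score higher than the negative pair by `margin`.
    return max(0.0, neg_sim + margin - pos_sim)

print(pointwise_loss(np.array([0.2, 1.3]), label=1))    # small loss, correct class favored
print(pairwise_hinge_loss(pos_sim=0.8, neg_sim=0.75))   # 0.05: gap is below the margin
```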
### Model Overview
The structure of SimNet is shown in the figure below:
<p align="center">
<img src="./struct.jpg"/> <br />
</p>
### Data Format
Training supports two modes: ```pairwise``` and ```pointwise```.
#### pairwise mode
The training set format is: query \t pos_query \t neg_query.
query, pos_query, and neg_query are whitespace-tokenized Chinese text separated by tab characters '\t'; pos_query is a positive example similar to query, and neg_query is a random negative example not similar to query; the text is UTF-8 encoded.</br>
```
现在 安卓模拟器 哪个 好 用 电脑 安卓模拟器 哪个 更好 电信 手机 可以 用 腾讯 大王 卡 吗 ?
土豆 一亩地 能 收 多少 斤 一亩 地土豆 产 多少 斤 一亩 地 用 多少 斤 土豆 种子
```
The development and test set format is: query1 \t query2 \t label.</br>
query1 and query2 are whitespace-tokenized Chinese text; label is 0 or 1, where 1 means query1 and query2 are similar and 0 means they are not; query1, query2, and label are separated by tab characters '\t'; the text is UTF-8 encoded.</br>
```
现在 安卓模拟器 哪个 好 用 电脑 安卓模拟器 哪个 更好 1
为什么 头发 掉 得 很厉害 我 头发 为什么 掉 得 厉害 1
常喝 薏米 水 有 副 作用 吗 女生 可以 长期 喝 薏米 水养生 么 0
长 的 清新 是 什么 意思 小 清新 的 意思 是 什么 0
```
#### pointwise mode
The training, development, and test sets share the same format: query1 and query2 are whitespace-tokenized Chinese text; label is 0 or 1, where 1 means query1 and query2 are similar and 0 means they are not; query1, query2, and label are separated by tab characters '\t'; the text is UTF-8 encoded.
```
现在 安卓模拟器 哪个 好 用 电脑 安卓模拟器 哪个 更好 1
为什么 头发 掉 得 很厉害 我 头发 为什么 掉 得 厉害 1
常喝 薏米 水 有 副 作用 吗 女生 可以 长期 喝 薏米 水养生 么 0
长 的 清新 是 什么 意思 小 清新 的 意思 是 什么 0
```
#### infer data
The infer data format is the same for ```pairwise``` and ```pointwise```: query1 \t query2.</br>
query1 and query2 are whitespace-tokenized Chinese text (a small parsing sketch follows the example below).
```
怎么 调理 湿热 体质 ? 湿热 体质 怎样 调理 啊
搞笑 电影 美国 搞笑 的 美国 电影
```
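For reference, a small sketch of how the three tab-separated formats above can be read; the project's real data pipeline lives in reader.py, so treat this only as an illustration of the column layout:
```python
import io

def read_tsv(path, n_fields):
    """Yield each line of `path` split into exactly n_fields tab-separated columns."""
    with io.open(path, "r", encoding="utf8") as fin:
        for line in fin:
            fields = line.rstrip("\n").split("\t")
            if len(fields) == n_fields:
                yield fields

# pairwise train:          query \t pos_query \t neg_query  -> read_tsv(path, 3)
# pointwise / dev / test:  query1 \t query2 \t label        -> read_tsv(path, 3)
# infer:                   query1 \t query2                 -> read_tsv(path, 2)
# e.g. for query1, query2 in read_tsv("./data/infer_data", 2): ...
```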
__Note__: this project also provides a word segmentation preprocessing script (under the preprocess directory). It can be used as follows:
```shell
python tokenizer.py --test_data_dir ./test.txt.utf8 --batch_size 1 > test.txt.utf8.seg
```
Here test.txt.utf8 is the file to be segmented, one text per line, UTF-8 encoded; the segmented result is written to test.txt.utf8.seg.
### Code Structure
```text
.
├── run_classifier.py: entry point of the project, wrapping training, prediction, and evaluation
├── config.py: defines the configuration class, which reads the model type and its hyperparameters
├── reader.py: data loading functions
├── utils.py: other common utility functions
├── config: configuration files for the different models
└── download.py: script for downloading the dataset and pretrained models
```
### Training
```shell
python run_classifier.py \
    --task_name ${TASK_NAME} \
    --use_cuda false \ # whether to use GPU
    --do_train True \ # whether to train
    --do_valid True \ # whether to evaluate the dev set during training
    --do_test True \ # whether to evaluate the test set
    --do_infer False \ # whether to run inference
    --batch_size 128 \ # batch size
    --train_data_dir ${TRAIN_DATA_PATH} \ # path to the training set
    --valid_data_dir ${VALID_DATA_PATH} \ # path to the dev set
    --test_data_dir ${TEST_DATA_PATH} \ # path to the test set
    --infer_data_dir ${INFER_DATA_PATH} \ # path to the data to be inferred
    --output_dir ${CKPT_PATH} \ # path where checkpoints are saved
    --config_path ${CONFIG_PATH} \ # path to the config file
    --vocab_path ${VOCAB_PATH} \ # path to the vocabulary
    --epoch 10 \ # number of epochs
    --save_steps 1000 \ # save a checkpoint every save_steps steps
    --validation_steps 100 \ # evaluate the dev set every validation_steps steps
    --task_mode ${TASK_MODE} \ # training mode, pairwise or pointwise, must match the config file
    --compute_accuracy False \ # whether to compute accuracy
    --lamda 0.91 \ # threshold used to compute accuracy in pairwise mode
    --init_checkpoint "" # path of the model to warm-start from
```
### Building Your Own Model
You can build a custom model according to your own needs as follows:
i. Define your network structure
You can define your own model under ```./models/```;
ii. Update the model configuration
Create a configuration file for the custom model by imitating the files under ```config``` (a minimal sketch appears after this list).
The configuration file must keep the ```net```, ```loss```, ```optimizer```, ```task_mode```, and ```model_path``` fields. ```net``` holds the parameters of your custom model; ```task_mode``` is the training mode, ```pairwise``` or ```pointwise```, and must match the ```--task_mode``` option of the training command; ```model_path``` is the path where the model is saved; ```loss``` and ```optimizer``` should be filled in following the other files under ```config```, according to what your model needs.
iii. Train the model by running the training, evaluation, and prediction scripts as described above.
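As a sketch for step ii, the snippet below writes a minimal configuration for a hypothetical custom model (class ```MyNet``` in ```./models/my_net.py```); the field names follow the existing files under ```config```, and all hyperparameter values are placeholders:
```python
import json

# Hypothetical configuration for a custom network; adjust the "net" block to
# whatever hyperparameters your model actually reads.
custom_config = {
    "net": {
        "module_name": "my_net",      # module under ./models/
        "class_name": "MyNet",        # class defined in that module
        "emb_dim": 128,
        "hidden_dim": 128
    },
    "loss": {                         # copied from the bow_pairwise config
        "module_name": "hinge_loss",
        "class_name": "HingeLoss",
        "margin": 0.1
    },
    "optimizer": {
        "class_name": "AdamOptimizer",
        "learning_rate": 0.001,
        "beta1": 0.9,
        "beta2": 0.999,
        "epsilon": 1e-08
    },
    "task_mode": "pairwise",          # must match --task_mode on the command line
    "model_path": "my_net_pairwise"
}

with open("./config/my_net_pairwise.json", "w") as fout:
    json.dump(custom_config, fout, indent=4)
```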
## Others
### Contributing
If you can fix an issue or add a new feature, feel free to send us a pull request. If the PR is accepted, we will score the contribution by quality and difficulty (0-5, higher is better). Once you accumulate 10 points, you can contact us for an interview opportunity or a recommendation letter.
# this file is only used for continuous evaluation test!
import os
import sys
sys.path.append(os.environ['ceroot'])
from kpi import CostKpi
from kpi import DurationKpi
from kpi import AccKpi
each_step_duration_simnet_card1 = DurationKpi('each_step_duration_simnet_card1', 0.03, 0, actived=True)
train_loss_simnet_card1 = CostKpi('train_loss_simnet_card1', 0.01, 0, actived=True)
each_step_duration_simnet_card4 = DurationKpi('each_step_duration_simnet_card4', 0.02, 0, actived=True)
train_loss_simnet_card4 = CostKpi('train_loss_simnet_card4', 0.01, 0, actived=True)
tracking_kpis = [
each_step_duration_simnet_card1,
train_loss_simnet_card1,
each_step_duration_simnet_card4,
train_loss_simnet_card4,
]
def parse_log(log):
    '''
    This method should be implemented by model developers.
    Each KPI line in the log should contain three tab-separated fields, for example:
    "
    kpis\ttrain_loss_simnet_card1\t1.0
    kpis\teach_step_duration_simnet_card1\t0.03
    "
    '''
for line in log.split('\n'):
fs = line.strip().split('\t')
print(fs)
if len(fs) == 3 and fs[0] == 'kpis':
kpi_name = fs[1]
kpi_value = float(fs[2])
yield kpi_name, kpi_value
def log_to_ce(log):
kpi_tracker = {}
for kpi in tracking_kpis:
kpi_tracker[kpi.name] = kpi
for (kpi_name, kpi_value) in parse_log(log):
print(kpi_name, kpi_value)
kpi_tracker[kpi_name].add_record(kpi_value)
kpi_tracker[kpi_name].persist()
if __name__ == '__main__':
log = sys.stdin.read()
log_to_ce(log)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SimNet config
"""
import six
import json
import io
class SimNetConfig(object):
"""
simnet Config
"""
def __init__(self, args):
self.task_mode = args.task_mode
self.config_path = args.config_path
self._config_dict = self._parse(args.config_path)
def _parse(self, config_path):
try:
with io.open(config_path) as json_file:
config_dict = json.load(json_file)
except Exception:
raise IOError("Error in parsing simnet model config file '%s'" % config_path)
else:
if config_dict["task_mode"] != self.task_mode:
raise ValueError(
"the config '{}' does not match the task_mode '{}'".format(self.config_path, self.task_mode))
return config_dict
def __getitem__(self, key):
return self._config_dict[key]
def __setitem__(self, key, value):
self._config_dict[key] = value
def print_config(self):
"""
Print Config
"""
for arg, value in sorted(six.iteritems(self._config_dict)):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
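# Usage sketch (not part of the original file): `args` stands in for the argparse
# namespace that run_classifier.py normally builds, and the config path is one of
# the files shipped under config/.
#
# import argparse
#
# args = argparse.Namespace(
#     task_mode="pointwise",
#     config_path="./config/bow_pointwise.json")
#
# config = SimNetConfig(args)          # raises ValueError if task_mode and the file disagree
# config.print_config()                # dump every field of the JSON file
# emb_dim = config["net"]["emb_dim"]   # nested values are reached through __getitem__
# config["dict_size"] = 10000          # __setitem__ lets callers add or override fields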
{
"net": {
"module_name": "bow",
"class_name": "BOW",
"emb_dim": 128,
"bow_dim": 128,
"hidden_dim": 128
},
"loss": {
"module_name": "hinge_loss",
"class_name": "HingeLoss",
"margin": 0.1
},
"optimizer": {
"class_name": "AdamOptimizer",
"learning_rate": 0.001,
"beta1": 0.9,
"beta2": 0.999,
"epsilon": 1e-08
},
"task_mode": "pairwise",
"model_path": "bow_pairwise"
}
{
"net": {
"module_name": "bow",
"class_name": "BOW",
"emb_dim": 128,
"bow_dim": 128
},
"loss": {
"module_name": "softmax_cross_entropy_loss",
"class_name": "SoftmaxCrossEntropyLoss"
},
"optimizer": {
"class_name": "AdamOptimizer",
"learning_rate": 0.001,
"beta1": 0.9,
"beta2": 0.999,
"epsilon": 1e-08
},
"task_mode": "pointwise",
"model_path": "bow_pointwise"
}
{
"net": {
"module_name": "cnn",
"class_name": "CNN",
"emb_dim": 128,
"filter_size": 3,
"num_filters": 256,
"hidden_dim": 128
},
"loss": {
"module_name": "hinge_loss",
"class_name": "HingeLoss",
"margin": 0.1
},
"optimizer": {
"class_name": "AdamOptimizer",
"learning_rate": 0.2,
"beta1": 0.9,
"beta2": 0.999,
"epsilon": 1e-08
},
"task_mode": "pairwise",
"model_path": "cnn_pairwise"
}
{
"net": {
"module_name": "cnn",
"class_name": "CNN",
"emb_dim": 128,
"filter_size": 3,
"num_filters": 256,
"hidden_dim": 128
},
"loss": {
"module_name": "softmax_cross_entropy_loss",
"class_name": "SoftmaxCrossEntropyLoss"
},
"optimizer": {
"class_name": "AdamOptimizer",
"learning_rate": 0.001,
"beta1": 0.9,
"beta2": 0.999,
"epsilon": 1e-08
},
"task_mode": "pointwise",
"model_path": "cnn_pointwise"
}
{
"net": {
"module_name": "gru",
"class_name": "GRU",
"emb_dim": 128,
"gru_dim": 128,
"hidden_dim": 128
},
"loss": {
"module_name": "hinge_loss",
"class_name": "HingeLoss",
"margin": 0.1
},
"optimizer": {
"class_name": "AdamOptimizer",
"learning_rate": 0.001,
"beta1": 0.9,
"beta2": 0.999,
"epsilon": 1e-08
},
"task_mode": "pairwise",
"model_path": "gru_pairwise"
}
{
"net": {
"module_name": "gru",
"class_name": "GRU",
"emb_dim": 128,
"gru_dim": 128,
"hidden_dim": 128
},
"loss": {
"module_name": "softmax_cross_entropy_loss",
"class_name": "SoftmaxCrossEntropyLoss"
},
"optimizer": {
"class_name": "AdamOptimizer",
"learning_rate" : 0.001,
"beta1": 0.9,
"beta2": 0.999,
"epsilon": 1e-08
},
"task_mode": "pointwise",
"model_path": "gru_pointwise"
}
{
"net": {
"module_name": "lstm",
"class_name": "LSTM",
"emb_dim": 128,
"lstm_dim": 128,
"hidden_dim": 128
},
"loss": {
"module_name": "hinge_loss",
"class_name": "HingeLoss",
"margin": 0.1
},
"optimizer": {
"class_name": "AdamOptimizer",
"learning_rate": 0.001,
"beta1": 0.9,
"beta2": 0.999,
"epsilon": 1e-08
},
"task_mode": "pairwise",
"model_path": "lstm_pairwise"
}
{
"net": {
"module_name": "lstm",
"class_name": "LSTM",
"emb_dim": 128,
"lstm_dim": 128,
"hidden_dim": 128
},
"loss": {
"module_name": "softmax_cross_entropy_loss",
"class_name": "SoftmaxCrossEntropyLoss"
},
"optimizer": {
"class_name": "AdamOptimizer",
"learning_rate": 0.001,
"beta1": 0.9,
"beta2": 0.999,
"epsilon": 1e-08
},
"task_mode": "pointwise",
"model_path": "lstm_pointwise"
}
{
"net": {
"module_name": "mm_dnn",
"class_name": "MMDNN",
"embedding_dim": 128,
"num_filters": 256,
"lstm_dim": 128,
"hidden_size": 128,
"window_size_left": 3,
"window_size_right": 3,
"dpool_size_left": 2,
"dpool_size_right": 2
},
"loss": {
"module_name": "softmax_cross_entropy_loss",
"class_name": "SoftmaxCrossEntropyLoss"
},
"optimizer": {
"class_name": "AdamOptimizer",
"learning_rate": 0.001,
"beta1": 0.9,
"beta2": 0.999,
"epsilon": 1e-08
},
"max_len_left": 32,
"max_len_right": 32,
"n_class": 2,
"task_mode": "pointwise",
"match_mask" : 1,
"model_path": "mm_dnn_pointwise"
}
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Download script, download dataset and pretrain models.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import os
import sys
import time
import hashlib
import tarfile
import requests
def usage():
desc = ("\nDownload datasets and pretrained models for SimilarityNet task.\n"
"Usage:\n"
" 1. python download.py dataset\n"
" 2. python download.py model\n")
print(desc)
def md5file(fname):
hash_md5 = hashlib.md5()
with io.open(fname, "rb") as fin:
for chunk in iter(lambda: fin.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def extract(fname, dir_path):
"""
Extract tar.gz file
"""
try:
tar = tarfile.open(fname, "r:gz")
file_names = tar.getnames()
for file_name in file_names:
tar.extract(file_name, dir_path)
print(file_name)
tar.close()
except Exception as e:
raise e
def download(url, filename, md5sum):
"""
Download file and check md5
"""
retry = 0
retry_limit = 3
chunk_size = 4096
while not (os.path.exists(filename) and md5file(filename) == md5sum):
if retry < retry_limit:
retry += 1
else:
raise RuntimeError("Cannot download dataset ({0}) with retry {1} times.".
format(url, retry_limit))
try:
start = time.time()
size = 0
res = requests.get(url, stream=True)
filesize = int(res.headers['content-length'])
if res.status_code == 200:
print("[Filesize]: %0.2f MB" % (filesize / 1024 / 1024))
# save by chunk
with io.open(filename, "wb") as fout:
for chunk in res.iter_content(chunk_size=chunk_size):
if chunk:
fout.write(chunk)
size += len(chunk)
pr = '>' * int(size * 50 / filesize)
print('\r[Process ]: %s%.2f%%' % (pr, float(size / filesize*100)), end='')
end = time.time()
print("\n[CostTime]: %.2f s" % (end - start))
except Exception as e:
print(e)
def download_dataset(dir_path):
BASE_URL = "https://baidu-nlp.bj.bcebos.com/"
DATASET_NAME = "simnet_dataset-1.0.0.tar.gz"
DATASET_MD5 = "ec65b313bc237150ef536a8d26f3c73b"
file_path = os.path.join(dir_path, DATASET_NAME)
url = BASE_URL + DATASET_NAME
if not os.path.exists(dir_path):
os.makedirs(dir_path)
# download dataset
print("Downloading dataset: %s" % url)
download(url, file_path, DATASET_MD5)
# extract dataset
print("Extracting dataset: %s" % file_path)
extract(file_path, dir_path)
os.remove(file_path)
def download_model(dir_path):
MODELS = {}
BASE_URL = "https://baidu-nlp.bj.bcebos.com/"
    MODEL_NAME = "simnet_bow-pairwise-1.0.0.tar.gz"
    MODEL_MD5 = "199a3f3af31558edcc71c3b54ea5e129"
    MODELS[MODEL_NAME] = MODEL_MD5
if not os.path.exists(dir_path):
os.makedirs(dir_path)
for model in MODELS:
url = BASE_URL + model
model_path = os.path.join(dir_path, model)
print("Downloading model: %s" % url)
# download model
download(url, model_path, MODELS[model])
# extract model.tar.gz
print("Extracting model: %s" % model_path)
extract(model_path, dir_path)
os.remove(model_path)
if __name__ == '__main__':
if len(sys.argv) != 2:
usage()
sys.exit(1)
if sys.argv[1] == "dataset":
pwd = os.path.join(os.path.dirname(__file__), './')
download_dataset(pwd)
elif sys.argv[1] == "model":
pwd = os.path.join(os.path.dirname(__file__), './model_files')
download_model(pwd)
else:
usage()
#get data
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/simnet_dataset-1.0.0.tar.gz
tar xzf simnet_dataset-1.0.0.tar.gz
rm simnet_dataset-1.0.0.tar.gz
#!/usr/bin/env bash
model_files_path="./model_files"
#get pretrained_bow_pairwise_model
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/simnet_bow-pairwise-1.0.0.tar.gz
if [ ! -d $model_files_path ]; then
mkdir $model_files_path
fi
tar xzf simnet_bow-pairwise-1.0.0.tar.gz -C $model_files_path
rm simnet_bow-pairwise-1.0.0.tar.gz
#!/usr/bin/env bash
export FLAGS_enable_parallel_graph=1
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=3
export FLAGS_fraction_of_gpu_memory_to_use=0.95
TASK_NAME='simnet'
TEST_DATA_PATH=./data/ecom
VOCAB_PATH=./data/term2id.dict
CKPT_PATH=./model_files
TEST_RESULT_PATH=./evaluate/ecom_test_result
TASK_MODE='pairwise'
CONFIG_PATH=./config/bow_pairwise.json
INIT_CHECKPOINT=./model_files/simnet_bow_pairwise_pretrained_model/
cd ..
python ./run_classifier.py \
--task_name ${TASK_NAME} \
--use_cuda false \
--do_test True \
--verbose_result True \
--batch_size 128 \
--test_data_dir ${TEST_DATA_PATH} \
--test_result_path ${TEST_RESULT_PATH} \
--config_path ${CONFIG_PATH} \
--vocab_path ${VOCAB_PATH} \
--task_mode ${TASK_MODE} \
--init_checkpoint ${INIT_CHECKPOINT}
#!/usr/bin/env bash
export FLAGS_enable_parallel_graph=1
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=3
export FLAGS_fraction_of_gpu_memory_to_use=0.95
TASK_NAME='simnet'
TEST_DATA_PATH=./data/qqsim
VOCAB_PATH=./data/term2id.dict
CKPT_PATH=./model_files
TEST_RESULT_PATH=./evaluate/qqsim_test_result
TASK_MODE='pairwise'
CONFIG_PATH=./config/bow_pairwise.json
INIT_CHECKPOINT=./model_files/simnet_bow_pairwise_pretrained_model/
cd ..
python ./run_classifier.py \
--task_name ${TASK_NAME} \
--use_cuda false \
--do_test True \
--verbose_result True \
--batch_size 128 \
--test_data_dir ${TEST_DATA_PATH} \
--test_result_path ${TEST_RESULT_PATH} \
--config_path ${CONFIG_PATH} \
--vocab_path ${VOCAB_PATH} \
--task_mode ${TASK_MODE} \
--init_checkpoint ${INIT_CHECKPOINT}
#!/usr/bin/env bash
export FLAGS_enable_parallel_graph=1
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=3
export FLAGS_fraction_of_gpu_memory_to_use=0.95
TASK_NAME='simnet'
INFER_DATA_PATH=./evaluate/unicom_infer
VOCAB_PATH=./data/term2id.dict
CKPT_PATH=./model_files
INFER_RESULT_PATH=./evaluate/unicom_infer_result
TASK_MODE='pairwise'
CONFIG_PATH=./config/bow_pairwise.json
INIT_CHECKPOINT=./model_files/simnet_bow_pairwise_pretrained_model/
python unicom_split.py
cd ..
python ./run_classifier.py \
--task_name ${TASK_NAME} \
--use_cuda false \
--do_infer True \
--batch_size 128 \
--infer_data_dir ${INFER_DATA_PATH} \
--infer_result_path ${INFER_RESULT_PATH} \
--config_path ${CONFIG_PATH} \
--vocab_path ${VOCAB_PATH} \
--task_mode ${TASK_MODE} \
--init_checkpoint ${INIT_CHECKPOINT}
cd evaluate
python unicom_compute_pos_neg.py
#!/usr/bin/env bash
export FLAGS_enable_parallel_graph=1
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=3
export FLAGS_fraction_of_gpu_memory_to_use=0.95
TASK_NAME='simnet'
TEST_DATA_PATH=./data/zhidao
VOCAB_PATH=./data/term2id.dict
CKPT_PATH=./model_files
TEST_RESULT_PATH=./evaluate/zhidao_test_result
TASK_MODE='pairwise'
CONFIG_PATH=./config/bow_pairwise.json
INIT_CHECKPOINT=./model_files/simnet_bow_pairwise_pretrained_model/
cd ..
python ./run_classifier.py \
--task_name ${TASK_NAME} \
--use_cuda false \
--do_test True \
--verbose_result True \
--batch_size 128 \
--test_data_dir ${TEST_DATA_PATH} \
--test_result_path ${TEST_RESULT_PATH} \
--config_path ${CONFIG_PATH} \
--vocab_path ${VOCAB_PATH} \
--task_mode ${TASK_MODE} \
--init_checkpoint ${INIT_CHECKPOINT}
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
compute the positive/negative order ratio on the UNICOM data
"""
import io
infer_results = []
labels = []
result = []
temp_result = []
temp_query = ""
pos_num = 0.0
neg_num = 0.0
with io.open("./unicom_infer_result", "r", encoding="utf8") as infer_result_file:
for line in infer_result_file:
infer_results.append(line.strip().split("\t"))
with io.open("./unicom_label", "r", encoding="utf8") as label_file:
for line in label_file:
labels.append(line.strip().split("\t"))
# Group consecutive rows that share the same left query; each group holds
# [query1, query2, score, label] rows for a single query.
for infer_result, label in zip(infer_results, labels):
    if infer_result[0] != temp_query and temp_query != "":
        result.append(temp_result)
        temp_query = infer_result[0]
        temp_result = []
        temp_result.append(infer_result + label)
    else:
        if temp_query == "":
            temp_query = infer_result[0]
        temp_result.append(infer_result + label)
else:
    # for-else: flush the last group after the loop finishes.
    result.append(temp_result)
for _result in result:
for n, i in enumerate(_result, start=1):
for j in _result[n:]:
if (int(j[-1]) > int(i[-1]) and float(j[-2]) < float(i[-2])) or (
int(j[-1]) < int(i[-1]) and float(j[-2]) > float(i[-2])):
neg_num += 1
elif (int(j[-1]) > int(i[-1]) and float(j[-2]) > float(i[-2])) or (
int(j[-1]) < int(i[-1]) and float(j[-2]) < float(i[-2])):
pos_num += 1
print("pos/neg of unicom data is %f" % (pos_num / neg_num))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
split unicom file
"""
import io
with io.open("../data/unicom", "r", encoding="utf8") as unicom_file:
with io.open("./unicom_infer", "w", encoding="utf8") as infer_file:
with io.open("./unicom_label", "w", encoding="utf8") as label_file:
for line in unicom_file:
line = line.strip().split('\t')
infer_file.write("\t".join(line[:2]) + '\n')
label_file.write(line[2] + '\n')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MMDNN class
"""
import numpy as np
import paddle.fluid as fluid
import logging
from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, to_variable, Layer, guard
from paddle.fluid.dygraph.nn import Conv2D
import paddle_layers as pd_layers
from paddle.fluid import layers
from paddle.fluid.dygraph import Layer
class BasicLSTMUnit(Layer):
"""
****
BasicLSTMUnit class, Using basic operator to build LSTM
The algorithm can be described as the code below.
.. math::
i_t &= \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + b_i)
f_t &= \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + b_f + forget_bias )
o_t &= \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + b_o)
\\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c)
c_t &= f_t \odot c_{t-1} + i_t \odot \\tilde{c_t}
h_t &= o_t \odot tanh(c_t)
- $W$ terms denote weight matrices (e.g. $W_{ix}$ is the matrix
of weights from the input gate to the input)
- The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vector).
- sigmoid is the logistic sigmoid function.
- $i, f, o$ and $c$ are the input gate, forget gate, output gate,
and cell activation vectors, respectively, all of which have the same size as
the cell output activation vector $h$.
- The :math:`\odot` is the element-wise product of the vectors.
- :math:`tanh` is the activation functions.
- :math:`\\tilde{c_t}` is also called candidate hidden state,
which is computed based on the current input and the previous hidden state.
Args:
        hidden_size (integer): The hidden size used in the Unit.
        input_size (integer): The input size used in the Unit.
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight matrix. Note:
If it is set to None or one attribute of ParamAttr, lstm_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|None): The parameter attribute for the bias
of LSTM unit.
If it is set to None or one attribute of ParamAttr, lstm_unit will
create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized as zero. Default: None.
gate_activation (function|None): The activation function for gates (actGate).
Default: 'fluid.layers.sigmoid'
activation (function|None): The activation function for cells (actNode).
Default: 'fluid.layers.tanh'
forget_bias(float|1.0): forget bias used when computing forget gate
dtype(string): data type used in this unit
"""
def __init__(self,
hidden_size,
input_size,
param_attr=None,
bias_attr=None,
gate_activation=None,
activation=None,
forget_bias=1.0,
dtype='float32'):
super(BasicLSTMUnit, self).__init__(dtype)
self._hiden_size = hidden_size
self._param_attr = param_attr
self._bias_attr = bias_attr
self._gate_activation = gate_activation or layers.sigmoid
self._activation = activation or layers.tanh
self._forget_bias = layers.fill_constant(
[1], dtype=dtype, value=forget_bias)
self._forget_bias.stop_gradient = False
self._dtype = dtype
self._input_size = input_size
self._weight = self.create_parameter(
attr=self._param_attr,
shape=[self._input_size + self._hiden_size, 4 * self._hiden_size],
dtype=self._dtype)
self._bias = self.create_parameter(
attr=self._bias_attr,
shape=[4 * self._hiden_size],
dtype=self._dtype,
is_bias=True)
def forward(self, input, pre_hidden, pre_cell):
concat_input_hidden = layers.concat([input, pre_hidden], 1)
gate_input = layers.matmul(x=concat_input_hidden, y=self._weight)
gate_input = layers.elementwise_add(gate_input, self._bias)
i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
new_cell = layers.elementwise_add(
layers.elementwise_mul(
pre_cell,
layers.sigmoid(layers.elementwise_add(f, self._forget_bias))),
layers.elementwise_mul(layers.sigmoid(i), layers.tanh(j)))
new_hidden = layers.tanh(new_cell) * layers.sigmoid(o)
return new_hidden, new_cell
class MMDNN(object):
"""
MMDNN
"""
def __init__(self, config):
"""
initialize
"""
self.vocab_size = int(config['dict_size'])
self.emb_size = int(config['net']['embedding_dim'])
self.lstm_dim = int(config['net']['lstm_dim'])
self.kernel_size = int(config['net']['num_filters'])
self.win_size1 = int(config['net']['window_size_left'])
self.win_size2 = int(config['net']['window_size_right'])
self.dpool_size1 = int(config['net']['dpool_size_left'])
self.dpool_size2 = int(config['net']['dpool_size_right'])
self.hidden_size = int(config['net']['hidden_size'])
self.seq_len1 = int(config['max_len_left'])
self.seq_len2 = int(config['max_len_right'])
self.task_mode = config['task_mode']
if int(config['match_mask']) != 0:
self.match_mask = True
else:
self.match_mask = False
if self.task_mode == "pointwise":
self.n_class = int(config['n_class'])
self.out_size = self.n_class
elif self.task_mode == "pairwise":
self.out_size = 1
else:
logging.error("training mode not supported")
def embedding_layer(self, input, zero_pad=True, scale=True):
"""
embedding layer
"""
emb = Embedding(
size=[self.vocab_size, self.emb_size],
padding_idx=(0 if zero_pad else None),
param_attr=fluid.ParamAttr(
name="word_embedding", initializer=fluid.initializer.Xavier()))
emb = emb(input)
if scale:
emb = emb * (self.emb_size**0.5)
return emb
def bi_dynamic_lstm(self, input, hidden_size):
"""
bi_lstm layer
"""
fw_in_proj = Linear(
input_dim=self.emb_size,
output_dim=4 * hidden_size,
param_attr=fluid.ParamAttr(name="fw_fc.w"),
bias_attr=False)
fw_in_proj = fw_in_proj(input)
forward = pd_layers.DynamicLSTMLayer(
size=4 * hidden_size,
is_reverse=False,
param_attr=fluid.ParamAttr(name="forward_lstm.w"),
bias_attr=fluid.ParamAttr(name="forward_lstm.b")).ops()
forward = forward(fw_in_proj)
rv_in_proj = Linear(
input_dim=self.emb_size,
output_dim=4 * hidden_size,
param_attr=fluid.ParamAttr(name="rv_fc.w"),
bias_attr=False)
rv_in_proj = rv_in_proj(input)
reverse = pd_layers.DynamicLSTMLayer(
4 * hidden_size,
            'lstm',
is_reverse=True,
param_attr=fluid.ParamAttr(name="reverse_lstm.w"),
bias_attr=fluid.ParamAttr(name="reverse_lstm.b")).ops()
reverse = reverse(rv_in_proj)
return [forward, reverse]
def conv_pool_relu_layer(self, input, mask=None):
"""
convolution and pool layer
"""
# data format NCHW
emb_expanded = fluid.layers.unsqueeze(input=input, axes=[1])
# same padding
        conv = Conv2D(
            num_channels=1,  # the input was unsqueezed to a single channel above
            num_filters=self.kernel_size,
stride=1,
padding=(int(self.seq_len1 / 2), int(self.seq_len2 // 2)),
filter_size=(self.seq_len1, self.seq_len2),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.1)))
conv = conv(emb_expanded)
if mask is not None:
cross_mask = fluid.layers.stack(x=[mask] * self.kernel_size, axis=1)
conv = cross_mask * conv + (1 - cross_mask) * (-2**32 + 1)
# valid padding
pool = fluid.layers.pool2d(
input=conv,
pool_size=[
int(self.seq_len1 / self.dpool_size1),
int(self.seq_len2 / self.dpool_size2)
],
pool_stride=[
int(self.seq_len1 / self.dpool_size1),
int(self.seq_len2 / self.dpool_size2)
],
pool_type="max", )
relu = fluid.layers.relu(pool)
return relu
def get_cross_mask(self, left_lens, right_lens):
"""
cross mask
"""
mask1 = fluid.layers.sequence_mask(
x=left_lens, dtype='float32', maxlen=self.seq_len1 + 1)
mask2 = fluid.layers.sequence_mask(
x=right_lens, dtype='float32', maxlen=self.seq_len2 + 1)
mask1 = fluid.layers.transpose(x=mask1, perm=[0, 2, 1])
cross_mask = fluid.layers.matmul(x=mask1, y=mask2)
return cross_mask
def predict(self, left, right):
"""
Forward network
"""
left_emb = self.embedding_layer(left, zero_pad=True, scale=False)
right_emb = self.embedding_layer(right, zero_pad=True, scale=False)
bi_left_outputs = self.bi_dynamic_lstm(
input=left_emb, hidden_size=self.lstm_dim)
left_seq_encoder = fluid.layers.concat(input=bi_left_outputs, axis=1)
bi_right_outputs = self.bi_dynamic_lstm(
input=right_emb, hidden_size=self.lstm_dim)
right_seq_encoder = fluid.layers.concat(input=bi_right_outputs, axis=1)
pad_value = fluid.layers.assign(input=np.array([0]).astype("float32"))
left_seq_encoder, left_lens = fluid.layers.sequence_pad(
x=left_seq_encoder, pad_value=pad_value, maxlen=self.seq_len1)
right_seq_encoder, right_lens = fluid.layers.sequence_pad(
x=right_seq_encoder, pad_value=pad_value, maxlen=self.seq_len2)
cross = fluid.layers.matmul(
left_seq_encoder, right_seq_encoder, transpose_y=True)
if self.match_mask:
cross_mask = self.get_cross_mask(left_lens, right_lens)
else:
cross_mask = None
conv_pool_relu = self.conv_pool_relu_layer(input=cross, mask=cross_mask)
relu_hid1 = Linear(
input_dim=conv_pool_relu.shape[-1],
output_dim=self.hidden_size)
relu_hid1 = relu_hid1(conv_pool_relu)
relu_hid1 = fluid.layers.tanh(relu_hid1)
        pred_fc = Linear(
            input_dim=relu_hid1.shape[-1],
            output_dim=self.out_size)
        pred = pred_fc(relu_hid1)
pred = fluid.layers.softmax(pred)
return left_seq_encoder, pred
#encoding=utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import paddle
import paddle.fluid as fluid
def check_cuda(use_cuda, err = \
"\nYou can not set use_cuda = True in the model because you are using paddlepaddle-cpu.\n \
Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda = False to run models on CPU.\n"
):
"""
Log error and exit when set use_gpu=true in paddlepaddle
cpu version.
"""
try:
        if use_cuda and not fluid.is_compiled_with_cuda():
print(err)
sys.exit(1)
except Exception as e:
pass
def check_version():
"""
Log error and exit when the installed version of paddlepaddle is
not satisfied.
"""
err = "PaddlePaddle version 1.6 or higher is required, " \
"or a suitable develop version is satisfied as well. \n" \
"Please make sure the version is good with your code." \
try:
fluid.require_version('1.6.0')
except Exception as e:
print(err)
sys.exit(1)
if __name__ == "__main__":
check_cuda(True)
check_cuda(False)
check_cuda(True, "This is only for testing.")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
base layers
"""
from paddle.fluid import layers
from paddle.fluid.dygraph import Layer
from paddle.fluid.dygraph import GRUUnit
from paddle.fluid.dygraph.base import to_variable
# import numpy as np
# import logging
class DynamicGRU(Layer):
def __init__(self,
size,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
h_0=None,
origin_mode=False,
init_size = None):
super(DynamicGRU, self).__init__()
self.gru_unit = GRUUnit(
size * 3,
param_attr=param_attr,
bias_attr=bias_attr,
activation=candidate_activation,
gate_activation=gate_activation,
origin_mode=origin_mode)
self.size = size
self.h_0 = h_0
self.is_reverse = is_reverse
    def forward(self, inputs):
        # Unroll the GRU cell over the time dimension (axis 1) step by step.
        hidden = self.h_0
        res = []
        for i in range(inputs.shape[1]):
            if self.is_reverse:
                i = inputs.shape[1] - 1 - i
            input_ = inputs[:, i:i + 1, :]
            input_ = layers.reshape(input_, [-1, input_.shape[2]], inplace=False)
            hidden, reset, gate = self.gru_unit(input_, hidden)
            hidden_ = layers.reshape(hidden, [-1, 1, hidden.shape[1]], inplace=False)
            res.append(hidden_)
        if self.is_reverse:
            res = res[::-1]
        res = layers.concat(res, axis=1)
        return res
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
bow class
"""
import paddle_layers as layers
from paddle import fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import Layer, Embedding
import paddle.fluid.param_attr as attr
uniform_initializer = lambda x: fluid.initializer.UniformInitializer(low=-x, high=x)
class BOW(Layer):
"""
BOW
"""
def __init__(self, conf_dict):
"""
initialize
"""
super(BOW, self).__init__()
self.dict_size = conf_dict["dict_size"]
self.task_mode = conf_dict["task_mode"]
self.emb_dim = conf_dict["net"]["emb_dim"]
self.bow_dim = conf_dict["net"]["bow_dim"]
self.seq_len = 5
self.emb_layer = layers.EmbeddingLayer(self.dict_size, self.emb_dim, "emb").ops()
self.bow_layer = layers.FCLayer(self.bow_dim, None, "fc").ops()
self.softmax_layer = layers.FCLayer(2, "softmax", "cos_sim").ops()
def forward(self, left, right):
"""
Forward network
"""
# embedding layer
left_emb = self.emb_layer(left)
right_emb = self.emb_layer(right)
left_emb = fluid.layers.reshape(
left_emb, shape=[-1, self.seq_len, self.bow_dim])
right_emb = fluid.layers.reshape(
right_emb, shape=[-1, self.seq_len, self.bow_dim])
bow_left = fluid.layers.reduce_sum(left_emb, dim=1)
bow_right = fluid.layers.reduce_sum(right_emb, dim=1)
softsign_layer = layers.SoftsignLayer()
left_soft = softsign_layer.ops(bow_left)
right_soft = softsign_layer.ops(bow_right)
# matching layer
if self.task_mode == "pairwise":
left_bow = self.bow_layer(left_soft)
right_bow = self.bow_layer(right_soft)
cos_sim_layer = layers.CosSimLayer()
pred = cos_sim_layer.ops(left_bow, right_bow)
return left_bow, pred
else:
concat_layer = layers.ConcatLayer(1)
concat = concat_layer.ops([left_soft, right_soft])
concat_fc = self.bow_layer(concat)
pred = self.softmax_layer(concat_fc)
return left_soft, pred
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
cnn class
"""
import paddle_layers as layers
from paddle import fluid
from paddle.fluid.dygraph import Layer
class CNN(Layer):
"""
CNN
"""
def __init__(self, conf_dict):
"""
initialize
"""
super(CNN, self).__init__()
self.dict_size = conf_dict["dict_size"]
self.task_mode = conf_dict["task_mode"]
self.emb_dim = conf_dict["net"]["emb_dim"]
self.filter_size = conf_dict["net"]["filter_size"]
self.num_filters = conf_dict["net"]["num_filters"]
self.hidden_dim = conf_dict["net"]["hidden_dim"]
self.seq_len = 5
self.channels = 1
# layers
self.emb_layer = layers.EmbeddingLayer(self.dict_size, self.emb_dim, "emb").ops()
self.fc_layer = layers.FCLayer(self.hidden_dim, None, "fc").ops()
self.softmax_layer = layers.FCLayer(2, "softmax", "cos_sim").ops()
self.cnn_layer = layers.SimpleConvPool(
self.channels,
self.num_filters,
self.filter_size)
def forward(self, left, right):
"""
Forward network
"""
# embedding layer
left_emb = self.emb_layer(left)
right_emb = self.emb_layer(right)
# Presentation context
left_emb = fluid.layers.reshape(
left_emb, shape=[-1, self.channels, self.seq_len, self.hidden_dim])
right_emb = fluid.layers.reshape(
right_emb, shape=[-1, self.channels, self.seq_len, self.hidden_dim])
left_cnn = self.cnn_layer(left_emb)
right_cnn = self.cnn_layer(right_emb)
# matching layer
if self.task_mode == "pairwise":
left_fc = self.fc_layer(left_cnn)
right_fc = self.fc_layer(right_cnn)
cos_sim_layer = layers.CosSimLayer()
pred = cos_sim_layer.ops(left_fc, right_fc)
return left_fc, pred
else:
concat_layer = layers.ConcatLayer(1)
concat = concat_layer.ops([left_cnn, right_cnn])
concat_fc = self.fc_layer(concat)
pred = self.softmax_layer(concat_fc)
return left_cnn, pred
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid import layers, unique_name
from paddle.fluid.dygraph import Layer
from paddle.fluid.dygraph.layer_object_helper import LayerObjectHelper
from paddle.fluid.layers.control_flow import StaticRNN
__all__ = ['BasicGRUUnit', 'basic_gru', 'BasicLSTMUnit', 'basic_lstm']
class BasicGRUUnit(Layer):
"""
****
BasicGRUUnit class, using basic operators to build GRU
The algorithm can be described as the equations below.
.. math::
u_t & = actGate(W_ux xu_{t} + W_uh h_{t-1} + b_u)
r_t & = actGate(W_rx xr_{t} + W_rh h_{t-1} + b_r)
m_t & = actNode(W_cx xm_t + W_ch dot(r_t, h_{t-1}) + b_m)
h_t & = dot(u_t, h_{t-1}) + dot((1-u_t), m_t)
Args:
name_scope(string) : The name scope used to identify parameters and biases
hidden_size (integer): The hidden size used in the Unit.
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight matrix. Note:
If it is set to None or one attribute of ParamAttr, gru_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|None): The parameter attribute for the bias
of GRU unit.
If it is set to None or one attribute of ParamAttr, gru_unit will
create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
gate_activation (function|None): The activation function for gates (actGate).
Default: 'fluid.layers.sigmoid'
activation (function|None): The activation function for cell (actNode).
Default: 'fluid.layers.tanh'
dtype(string): data type used in this unit
Examples:
.. code-block:: python
import paddle.fluid.layers as layers
from paddle.fluid.contrib.layers import BasicGRUUnit
input_size = 128
hidden_size = 256
input = layers.data( name = "input", shape = [-1, input_size], dtype='float32')
pre_hidden = layers.data( name = "pre_hidden", shape=[-1, hidden_size], dtype='float32')
gru_unit = BasicGRUUnit( "gru_unit", hidden_size )
new_hidden = gru_unit( input, pre_hidden )
"""
def __init__(self,
name_scope,
hidden_size,
param_attr=None,
bias_attr=None,
gate_activation=None,
activation=None,
dtype='float32'):
super(BasicGRUUnit, self).__init__(name_scope, dtype)
# reserve old school _full_name and _helper for static graph save load
self._full_name = unique_name.generate(name_scope + "/" +
self.__class__.__name__)
self._helper = LayerObjectHelper(self._full_name)
self._name = name_scope
self._hiden_size = hidden_size
self._param_attr = param_attr
self._bias_attr = bias_attr
self._gate_activation = gate_activation or layers.sigmoid
self._activation = activation or layers.tanh
self._dtype = dtype
def _build_once(self, input, pre_hidden):
self._input_size = input.shape[-1]
assert (self._input_size > 0)
self._gate_weight = self.create_parameter(
attr=self._param_attr,
shape=[self._input_size + self._hiden_size, 2 * self._hiden_size],
dtype=self._dtype)
self._candidate_weight = self.create_parameter(
attr=self._param_attr,
shape=[self._input_size + self._hiden_size, self._hiden_size],
dtype=self._dtype)
self._gate_bias = self.create_parameter(
attr=self._bias_attr,
shape=[2 * self._hiden_size],
dtype=self._dtype,
is_bias=True)
self._candidate_bias = self.create_parameter(
attr=self._bias_attr,
shape=[self._hiden_size],
dtype=self._dtype,
is_bias=True)
def forward(self, input, pre_hidden):
concat_input_hidden = layers.concat([input, pre_hidden], 1)
gate_input = layers.matmul(x=concat_input_hidden, y=self._gate_weight)
gate_input = layers.elementwise_add(gate_input, self._gate_bias)
gate_input = self._gate_activation(gate_input)
r, u = layers.split(gate_input, num_or_sections=2, dim=1)
r_hidden = r * pre_hidden
candidate = layers.matmul(
layers.concat([input, r_hidden], 1), self._candidate_weight)
candidate = layers.elementwise_add(candidate, self._candidate_bias)
c = self._activation(candidate)
new_hidden = u * pre_hidden + (1 - u) * c
return new_hidden
def basic_gru(input,
init_hidden,
hidden_size,
num_layers=1,
sequence_length=None,
dropout_prob=0.0,
bidirectional=False,
batch_first=True,
param_attr=None,
bias_attr=None,
gate_activation=None,
activation=None,
dtype='float32',
name='basic_gru'):
"""
GRU implementation using basic operator, supports multiple layers and bidirection gru.
.. math::
u_t & = actGate(W_ux xu_{t} + W_uh h_{t-1} + b_u)
r_t & = actGate(W_rx xr_{t} + W_rh h_{t-1} + b_r)
m_t & = actNode(W_cx xm_t + W_ch dot(r_t, h_{t-1}) + b_m)
h_t & = dot(u_t, h_{t-1}) + dot((1-u_t), m_t)
Args:
input (Variable): GRU input tensor,
if batch_first = False, shape should be ( seq_len x batch_size x input_size )
if batch_first = True, shape should be ( batch_size x seq_len x hidden_size )
init_hidden(Variable|None): The initial hidden state of the GRU
This is a tensor with shape ( num_layers x batch_size x hidden_size)
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
and can be reshaped to tensor with ( num_layers x 2 x batch_size x hidden_size) to use.
If it's None, it will be set to all 0.
hidden_size (int): Hidden size of the GRU
num_layers (int): The total number of layers of the GRU
        sequence_length (Variable|None): A Tensor (shape [batch_size]) that stores the real length of each instance.
            This tensor will be converted to a mask to mask out the padding ids.
            If it is None, there are no padding ids.
        dropout_prob(float|0.0): Dropout prob, dropout ONLY works after the rnn output of each layer,
            NOT between time steps
bidirectional (bool|False): If it is bidirectional
batch_first (bool|True): The shape format of the input and output tensors. If true,
the shape format should be :attr:`[batch_size, seq_len, hidden_size]`. If false,
the shape format should be :attr:`[seq_len, batch_size, hidden_size]`. By default
this function accepts input and emits output in batch-major form to be consistent
with most of data format, though a bit less efficient because of extra transposes.
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight matrix. Note:
If it is set to None or one attribute of ParamAttr, gru_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|None): The parameter attribute for the bias
of GRU unit.
If it is set to None or one attribute of ParamAttr, gru_unit will
create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
gate_activation (function|None): The activation function for gates (actGate).
Default: 'fluid.layers.sigmoid'
activation (function|None): The activation function for cell (actNode).
Default: 'fluid.layers.tanh'
dtype(string): data type used in this unit
name(string): name used to identify parameters and biases
Returns:
rnn_out(Tensor),last_hidden(Tensor)
- rnn_out is result of GRU hidden, with shape (seq_len x batch_size x hidden_size) \
            if is_bidirec is set to True, shape will be ( seq_len x batch_size x hidden_size*2)
- last_hidden is the hidden state of the last step of GRU \
shape is ( num_layers x batch_size x hidden_size ) \
if is_bidirec set to True, shape will be ( num_layers*2 x batch_size x hidden_size),
can be reshaped to a tensor with shape( num_layers x 2 x batch_size x hidden_size)
Examples:
.. code-block:: python
import paddle.fluid.layers as layers
from paddle.fluid.contrib.layers import basic_gru
batch_size = 20
input_size = 128
hidden_size = 256
num_layers = 2
dropout = 0.5
bidirectional = True
batch_first = False
input = layers.data( name = "input", shape = [-1, batch_size, input_size], dtype='float32')
pre_hidden = layers.data( name = "pre_hidden", shape=[-1, hidden_size], dtype='float32')
sequence_length = layers.data( name="sequence_length", shape=[-1], dtype='int32')
rnn_out, last_hidden = basic_gru( input, pre_hidden, hidden_size, num_layers = num_layers, \
sequence_length = sequence_length, dropout_prob=dropout, bidirectional = bidirectional, \
batch_first = batch_first)
"""
fw_unit_list = []
for i in range(num_layers):
new_name = name + "_layers_" + str(i)
fw_unit_list.append(
BasicGRUUnit(new_name, hidden_size, param_attr, bias_attr,
gate_activation, activation, dtype))
if bidirectional:
bw_unit_list = []
for i in range(num_layers):
new_name = name + "_reverse_layers_" + str(i)
bw_unit_list.append(
BasicGRUUnit(new_name, hidden_size, param_attr, bias_attr,
gate_activation, activation, dtype))
if batch_first:
input = layers.transpose(input, [1, 0, 2])
mask = None
if sequence_length:
max_seq_len = layers.shape(input)[0]
mask = layers.sequence_mask(
sequence_length, maxlen=max_seq_len, dtype='float32')
mask = layers.transpose(mask, [1, 0])
direc_num = 1
if bidirectional:
direc_num = 2
if init_hidden:
init_hidden = layers.reshape(
init_hidden, shape=[num_layers, direc_num, -1, hidden_size])
def get_single_direction_output(rnn_input,
unit_list,
mask=None,
direc_index=0):
rnn = StaticRNN()
with rnn.step():
step_input = rnn.step_input(rnn_input)
if mask:
step_mask = rnn.step_input(mask)
for i in range(num_layers):
if init_hidden:
pre_hidden = rnn.memory(init=init_hidden[i, direc_index])
else:
pre_hidden = rnn.memory(
batch_ref=rnn_input,
shape=[-1, hidden_size],
ref_batch_dim_idx=1)
new_hidden = unit_list[i](step_input, pre_hidden)
if mask:
new_hidden = layers.elementwise_mul(
new_hidden, step_mask, axis=0) - layers.elementwise_mul(
pre_hidden, (step_mask - 1), axis=0)
rnn.update_memory(pre_hidden, new_hidden)
rnn.step_output(new_hidden)
step_input = new_hidden
if dropout_prob != None and dropout_prob > 0.0:
step_input = layers.dropout(
step_input,
dropout_prob=dropout_prob, )
rnn.step_output(step_input)
rnn_out = rnn()
last_hidden_array = []
rnn_output = rnn_out[-1]
for i in range(num_layers):
last_hidden = rnn_out[i]
last_hidden = last_hidden[-1]
last_hidden_array.append(last_hidden)
last_hidden_output = layers.concat(last_hidden_array, axis=0)
last_hidden_output = layers.reshape(
last_hidden_output, shape=[num_layers, -1, hidden_size])
return rnn_output, last_hidden_output
# seq_len, batch_size, hidden_size
fw_rnn_out, fw_last_hidden = get_single_direction_output(
input, fw_unit_list, mask, direc_index=0)
if bidirectional:
bw_input = layers.reverse(input, axis=[0])
bw_mask = None
if mask:
bw_mask = layers.reverse(mask, axis=[0])
bw_rnn_out, bw_last_hidden = get_single_direction_output(
bw_input, bw_unit_list, bw_mask, direc_index=1)
bw_rnn_out = layers.reverse(bw_rnn_out, axis=[0])
rnn_out = layers.concat([fw_rnn_out, bw_rnn_out], axis=2)
last_hidden = layers.concat([fw_last_hidden, bw_last_hidden], axis=1)
last_hidden = layers.reshape(
last_hidden, shape=[num_layers * direc_num, -1, hidden_size])
if batch_first:
rnn_out = layers.transpose(rnn_out, [1, 0, 2])
return rnn_out, last_hidden
else:
rnn_out = fw_rnn_out
last_hidden = fw_last_hidden
if batch_first:
rnn_out = layers.transpose(rnn_out, [1, 0, 2])
return rnn_out, last_hidden
def basic_lstm(input,
init_hidden,
init_cell,
hidden_size,
num_layers=1,
sequence_length=None,
dropout_prob=0.0,
bidirectional=False,
batch_first=True,
param_attr=None,
bias_attr=None,
gate_activation=None,
activation=None,
forget_bias=1.0,
dtype='float32',
name='basic_lstm'):
"""
LSTM implementation using basic operators, supports multiple layers and bidirection LSTM.
.. math::
i_t &= \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + b_i)
f_t &= \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + b_f + forget_bias )
o_t &= \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + b_o)
\\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c)
c_t &= f_t \odot c_{t-1} + i_t \odot \\tilde{c_t}
h_t &= o_t \odot tanh(c_t)
Args:
input (Variable): lstm input tensor,
if batch_first = False, shape should be ( seq_len x batch_size x input_size )
if batch_first = True, shape should be ( batch_size x seq_len x hidden_size )
init_hidden(Variable|None): The initial hidden state of the LSTM
This is a tensor with shape ( num_layers x batch_size x hidden_size)
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
and can be reshaped to a tensor with shape ( num_layers x 2 x batch_size x hidden_size) to use.
If it's None, it will be set to all 0.
init_cell(Variable|None): The initial hidden state of the LSTM
This is a tensor with shape ( num_layers x batch_size x hidden_size)
if is_bidirec = True, shape should be ( num_layers*2 x batch_size x hidden_size)
and can be reshaped to a tensor with shape ( num_layers x 2 x batch_size x hidden_size) to use.
If it's None, it will be set to all 0.
hidden_size (int): Hidden size of the LSTM
num_layers (int): The total number of layers of the LSTM
        sequence_length (Variable|None): A Tensor (shape [batch_size]) that stores the real length of each instance.
            This tensor will be converted to a mask to mask out the padding ids.
            If it is None, there are no padding ids.
        dropout_prob(float|0.0): Dropout prob, dropout ONLY works after the rnn output of each layer,
            NOT between time steps
bidirectional (bool|False): If it is bidirectional
batch_first (bool|True): The shape format of the input and output tensors. If true,
the shape format should be :attr:`[batch_size, seq_len, hidden_size]`. If false,
the shape format should be :attr:`[seq_len, batch_size, hidden_size]`. By default
this function accepts input and emits output in batch-major form to be consistent
with most of data format, though a bit less efficient because of extra transposes.
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight matrix. Note:
If it is set to None or one attribute of ParamAttr, lstm_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|None): The parameter attribute for the bias
of LSTM unit.
If it is set to None or one attribute of ParamAttr, lstm_unit will
create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
gate_activation (function|None): The activation function for gates (actGate).
Default: 'fluid.layers.sigmoid'
activation (function|None): The activation function for cell (actNode).
Default: 'fluid.layers.tanh'
forget_bias (float|1.0) : Forget bias used to compute the forget gate
dtype(string): Data type used in this unit
name(string): Name used to identify parameters and biases
Returns:
rnn_out(Tensor), last_hidden(Tensor), last_cell(Tensor)
- rnn_out is the result of LSTM hidden, shape is (seq_len x batch_size x hidden_size) \
            if is_bidirec is set to True, its shape will be ( seq_len x batch_size x hidden_size*2)
- last_hidden is the hidden state of the last step of LSTM \
with shape ( num_layers x batch_size x hidden_size ) \
if is_bidirec set to True, it's shape will be ( num_layers*2 x batch_size x hidden_size),
and can be reshaped to a tensor ( num_layers x 2 x batch_size x hidden_size) to use.
- last_cell is the hidden state of the last step of LSTM \
with shape ( num_layers x batch_size x hidden_size ) \
if is_bidirec set to True, it's shape will be ( num_layers*2 x batch_size x hidden_size),
and can be reshaped to a tensor ( num_layers x 2 x batch_size x hidden_size) to use.
Examples:
.. code-block:: python
import paddle.fluid.layers as layers
from paddle.fluid.contrib.layers import basic_lstm
batch_size = 20
input_size = 128
hidden_size = 256
num_layers = 2
dropout = 0.5
bidirectional = True
batch_first = False
input = layers.data( name = "input", shape = [-1, batch_size, input_size], dtype='float32')
pre_hidden = layers.data( name = "pre_hidden", shape=[-1, hidden_size], dtype='float32')
pre_cell = layers.data( name = "pre_cell", shape=[-1, hidden_size], dtype='float32')
sequence_length = layers.data( name="sequence_length", shape=[-1], dtype='int32')
rnn_out, last_hidden, last_cell = basic_lstm( input, pre_hidden, pre_cell, \
hidden_size, num_layers = num_layers, \
sequence_length = sequence_length, dropout_prob=dropout, bidirectional = bidirectional, \
batch_first = batch_first)
"""
fw_unit_list = []
for i in range(num_layers):
new_name = name + "_layers_" + str(i)
fw_unit_list.append(
BasicLSTMUnit(
new_name,
hidden_size,
param_attr=param_attr,
bias_attr=bias_attr,
gate_activation=gate_activation,
activation=activation,
forget_bias=forget_bias,
dtype=dtype))
if bidirectional:
bw_unit_list = []
for i in range(num_layers):
new_name = name + "_reverse_layers_" + str(i)
bw_unit_list.append(
BasicLSTMUnit(
new_name,
hidden_size,
param_attr=param_attr,
bias_attr=bias_attr,
gate_activation=gate_activation,
activation=activation,
forget_bias=forget_bias,
dtype=dtype))
if batch_first:
input = layers.transpose(input, [1, 0, 2])
mask = None
if sequence_length:
max_seq_len = layers.shape(input)[0]
mask = layers.sequence_mask(
sequence_length, maxlen=max_seq_len, dtype='float32')
mask = layers.transpose(mask, [1, 0])
direc_num = 1
if bidirectional:
direc_num = 2
# convert to [num_layers, 2, batch_size, hidden_size]
if init_hidden:
init_hidden = layers.reshape(
init_hidden, shape=[num_layers, direc_num, -1, hidden_size])
init_cell = layers.reshape(
init_cell, shape=[num_layers, direc_num, -1, hidden_size])
# forward direction
def get_single_direction_output(rnn_input,
unit_list,
mask=None,
direc_index=0):
rnn = StaticRNN()
with rnn.step():
step_input = rnn.step_input(rnn_input)
if mask:
step_mask = rnn.step_input(mask)
for i in range(num_layers):
if init_hidden:
pre_hidden = rnn.memory(init=init_hidden[i, direc_index])
pre_cell = rnn.memory(init=init_cell[i, direc_index])
else:
pre_hidden = rnn.memory(
batch_ref=rnn_input, shape=[-1, hidden_size])
pre_cell = rnn.memory(
batch_ref=rnn_input, shape=[-1, hidden_size])
new_hidden, new_cell = unit_list[i](step_input, pre_hidden,
pre_cell)
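# for steps beyond a sequence's real length, keep the previous state:
# new_state = new_state * mask + pre_state * (1 - mask)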
if mask:
new_hidden = layers.elementwise_mul(
new_hidden, step_mask, axis=0) - layers.elementwise_mul(
pre_hidden, (step_mask - 1), axis=0)
new_cell = layers.elementwise_mul(
new_cell, step_mask, axis=0) - layers.elementwise_mul(
pre_cell, (step_mask - 1), axis=0)
rnn.update_memory(pre_hidden, new_hidden)
rnn.update_memory(pre_cell, new_cell)
rnn.step_output(new_hidden)
rnn.step_output(new_cell)
step_input = new_hidden
if dropout_prob is not None and dropout_prob > 0.0:
step_input = layers.dropout(
step_input,
dropout_prob=dropout_prob,
dropout_implementation='upscale_in_train')
rnn.step_output(step_input)
rnn_out = rnn()
last_hidden_array = []
last_cell_array = []
rnn_output = rnn_out[-1]
for i in range(num_layers):
last_hidden = rnn_out[i * 2]
last_hidden = last_hidden[-1]
last_hidden_array.append(last_hidden)
last_cell = rnn_out[i * 2 + 1]
last_cell = last_cell[-1]
last_cell_array.append(last_cell)
last_hidden_output = layers.concat(last_hidden_array, axis=0)
last_hidden_output = layers.reshape(
last_hidden_output, shape=[num_layers, -1, hidden_size])
last_cell_output = layers.concat(last_cell_array, axis=0)
last_cell_output = layers.reshape(
last_cell_output, shape=[num_layers, -1, hidden_size])
return rnn_output, last_hidden_output, last_cell_output
# seq_len, batch_size, hidden_size
fw_rnn_out, fw_last_hidden, fw_last_cell = get_single_direction_output(
input, fw_unit_list, mask, direc_index=0)
if bidirectional:
bw_input = layers.reverse(input, axis=[0])
bw_mask = None
if mask:
bw_mask = layers.reverse(mask, axis=[0])
bw_rnn_out, bw_last_hidden, bw_last_cell = get_single_direction_output(
bw_input, bw_unit_list, bw_mask, direc_index=1)
bw_rnn_out = layers.reverse(bw_rnn_out, axis=[0])
rnn_out = layers.concat([fw_rnn_out, bw_rnn_out], axis=2)
last_hidden = layers.concat([fw_last_hidden, bw_last_hidden], axis=1)
last_hidden = layers.reshape(
last_hidden, shape=[num_layers * direc_num, -1, hidden_size])
last_cell = layers.concat([fw_last_cell, bw_last_cell], axis=1)
last_cell = layers.reshape(
last_cell, shape=[num_layers * direc_num, -1, hidden_size])
if batch_first:
rnn_out = layers.transpose(rnn_out, [1, 0, 2])
return rnn_out, last_hidden, last_cell
else:
rnn_out = fw_rnn_out
last_hidden = fw_last_hidden
last_cell = fw_last_cell
if batch_first:
rnn_out = layers.transpose(rnn_out, [1, 0, 2])
return rnn_out, last_hidden, last_cell
class BasicLSTMUnit(Layer):
"""
****
BasicLSTMUnit class, using basic operators to build an LSTM.
The algorithm can be described as the code below.
.. math::
i_t &= \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + b_i)
f_t &= \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + b_f + forget_bias )
o_t &= \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + b_o)
\\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c)
c_t &= f_t \odot c_{t-1} + i_t \odot \\tilde{c_t}
h_t &= o_t \odot tanh(c_t)
- $W$ terms denote weight matrices (e.g. $W_{ix}$ is the matrix
of weights from the input gate to the input)
- The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vector).
- sigmoid is the logistic sigmoid function.
- $i, f, o$ and $c$ are the input gate, forget gate, output gate,
and cell activation vectors, respectively, all of which have the same size as
the cell output activation vector $h$.
- The :math:`\odot` is the element-wise product of the vectors.
- :math:`tanh` is the activation function.
- :math:`\\tilde{c_t}` is also called candidate hidden state,
which is computed based on the current input and the previous hidden state.
Args:
name_scope(string) : The name scope used to identify parameter and bias name
hidden_size (integer): The hidden size used in the Unit.
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight matrix. Note:
If it is set to None or one attribute of ParamAttr, lstm_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|None): The parameter attribute for the bias
of LSTM unit.
If it is set to None or one attribute of ParamAttr, lstm_unit will
create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized as zero. Default: None.
gate_activation (function|None): The activation function for gates (actGate).
Default: 'fluid.layers.sigmoid'
activation (function|None): The activation function for cells (actNode).
Default: 'fluid.layers.tanh'
forget_bias(float|1.0): forget bias used when computing forget gate
dtype(string): data type used in this unit
Examples:
.. code-block:: python
import paddle.fluid.layers as layers
from paddle.fluid.contrib.layers import BasicLSTMUnit
input_size = 128
hidden_size = 256
input = layers.data( name = "input", shape = [-1, input_size], dtype='float32')
pre_hidden = layers.data( name = "pre_hidden", shape=[-1, hidden_size], dtype='float32')
pre_cell = layers.data( name = "pre_cell", shape=[-1, hidden_size], dtype='float32')
lstm_unit = BasicLSTMUnit( "lstm_unit", hidden_size)
new_hidden, new_cell = lstm_unit( input, pre_hidden, pre_cell )
"""
def __init__(self,
name_scope,
hidden_size,
param_attr=None,
bias_attr=None,
gate_activation=None,
activation=None,
forget_bias=1.0,
dtype='float32'):
super(BasicLSTMUnit, self).__init__(name_scope, dtype)
# reserve old school _full_name and _helper for static graph save load
self._full_name = unique_name.generate(name_scope + "/" +
self.__class__.__name__)
self._helper = LayerObjectHelper(self._full_name)
self._name = name_scope
self._hiden_size = hidden_size
self._param_attr = param_attr
self._bias_attr = bias_attr
self._gate_activation = gate_activation or layers.sigmoid
self._activation = activation or layers.tanh
self._forget_bias = layers.fill_constant(
[1], dtype=dtype, value=forget_bias)
self._forget_bias.stop_gradient = False
self._dtype = dtype
def _build_once(self, input, pre_hidden, pre_cell):
self._input_size = input.shape[-1]
assert (self._input_size > 0)
self._weight = self.create_parameter(
attr=self._param_attr,
shape=[self._input_size + self._hiden_size, 4 * self._hiden_size],
dtype=self._dtype)
self._bias = self.create_parameter(
attr=self._bias_attr,
shape=[4 * self._hiden_size],
dtype=self._dtype,
is_bias=True)
def forward(self, input, pre_hidden, pre_cell):
concat_input_hidden = layers.concat([input, pre_hidden], 1)
gate_input = layers.matmul(x=concat_input_hidden, y=self._weight)
gate_input = layers.elementwise_add(gate_input, self._bias)
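# the fused projection is split in gate order [i, j, f, o], where j is the
# candidate cell state \tilde{c}; forget_bias is added to f before the sigmoid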
i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
new_cell = layers.elementwise_add(
layers.elementwise_mul(
pre_cell,
layers.sigmoid(layers.elementwise_add(f, self._forget_bias))),
layers.elementwise_mul(layers.sigmoid(i), layers.tanh(j)))
new_hidden = layers.tanh(new_cell) * layers.sigmoid(o)
return new_hidden, new_cell
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
gru class
"""
import paddle_layers as layers
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.dygraph import Layer
from paddle import fluid
import numpy as np
class GRU(Layer):
"""
GRU
"""
def __init__(self, conf_dict):
"""
initialize
"""
super(GRU, self).__init__()
self.dict_size = conf_dict["dict_size"]
self.task_mode = conf_dict["task_mode"]
self.emb_dim = conf_dict["net"]["emb_dim"]
self.gru_dim = conf_dict["net"]["gru_dim"]
self.hidden_dim = conf_dict["net"]["hidden_dim"]
self.emb_layer = layers.EmbeddingLayer(self.dict_size, self.emb_dim, "emb").ops()
self.gru_layer = layers.DynamicGRULayer(self.gru_dim, "gru").ops()
self.fc_layer = layers.FCLayer(self.hidden_dim, None, "fc").ops()
self.proj_layer = Linear(input_dim = self.hidden_dim, output_dim=self.gru_dim*3)
self.softmax_layer = layers.FCLayer(2, "softmax", "cos_sim").ops()
self.seq_len=5
def forward(self, left, right):
"""
Forward network
"""
# embedding layer
left_emb = self.emb_layer(left)
right_emb = self.emb_layer(right)
# Presentation context
left_emb = self.proj_layer(left_emb)
right_emb = self.proj_layer(right_emb)
h_0 = np.zeros((left_emb.shape[0], self.hidden_dim), dtype="float32")
h_0 = to_variable(h_0)
left_gru = self.gru_layer(left_emb, h_0=h_0)
right_gru = self.gru_layer(right_emb, h_0=h_0)
left_emb = fluid.layers.reduce_max(left_gru, dim=1)
right_emb = fluid.layers.reduce_max(right_gru, dim=1)
left_emb = fluid.layers.reshape(
left_emb, shape=[-1, self.seq_len, self.hidden_dim])
right_emb = fluid.layers.reshape(
right_emb, shape=[-1, self.seq_len, self.hidden_dim])
left_emb = fluid.layers.reduce_sum(left_emb, dim=1)
right_emb = fluid.layers.reduce_sum(right_emb, dim=1)
left_last = fluid.layers.tanh(left_emb)
right_last = fluid.layers.tanh(right_emb)
if self.task_mode == "pairwise":
left_fc = self.fc_layer(left_last)
right_fc = self.fc_layer(right_last)
cos_sim_layer = layers.CosSimLayer()
pred = cos_sim_layer.ops(left_fc, right_fc)
return left_fc, pred
else:
concat_layer = layers.ConcatLayer(1)
concat = concat_layer.ops([left_last, right_last])
concat_fc = self.fc_layer(concat)
pred = self.softmax_layer(concat_fc)
return left_last, pred
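# A minimal configuration sketch (illustrative only, the values are assumptions):
# GRU.__init__ reads the keys shown below from conf_dict; forward() takes two
# batches of word-id tensors (left, right) and returns (representation, pred),
# where pred is a cosine similarity in "pairwise" mode or a 2-way softmax in
# "pointwise" mode.
#
#   conf_dict = {
#       "dict_size": 10000,
#       "task_mode": "pairwise",
#       "net": {"emb_dim": 128, "gru_dim": 128, "hidden_dim": 128},
#   }
#   model = GRU(conf_dict)   # construct and call inside fluid.dygraph.guard()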
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
hinge loss
"""
import sys
sys.path.append("../")
import nets.paddle_layers as layers
class HingeLoss(object):
"""
Hinge Loss Calculate class
"""
def __init__(self, conf_dict):
"""
initialize
"""
self.margin = conf_dict["loss"]["margin"]
def compute(self, pos, neg):
"""
compute loss
"""
elementwise_max = layers.ElementwiseMaxLayer()
elementwise_add = layers.ElementwiseAddLayer()
elementwise_sub = layers.ElementwiseSubLayer()
constant = layers.ConstantLayer()
reduce_mean = layers.ReduceMeanLayer()
loss = reduce_mean.ops(
elementwise_max.ops(
constant.ops(neg, neg.shape, "float32", 0.0),
elementwise_add.ops(
elementwise_sub.ops(neg, pos),
constant.ops(neg, neg.shape, "float32", self.margin))))
return loss
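# The composition above evaluates the pairwise hinge loss
#     loss = mean(max(0, neg - pos + margin))
# where pos / neg are the similarity scores of the positive and negative pairs.
# A plain-Python check of the formula (illustrative numbers only):
#   margin = 0.1
#   max(0.0, 0.3 - 0.8 + margin)   # = 0.0, margin satisfied, no loss
#   max(0.0, 0.5 - 0.2 + margin)   # ~ 0.4, margin violated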
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
log loss
"""
import sys
sys.path.append("../")
import nets.paddle_layers as layers
class LogLoss(object):
"""
Log Loss Calculate
"""
def __init__(self, conf_dict):
"""
initialize
"""
pass
def compute(self, pos, neg):
"""
compute loss
"""
sigmoid = layers.SigmoidLayer()
reduce_mean = layers.ReduceMeanLayer()
loss = reduce_mean.ops(sigmoid.ops(neg - pos))
return loss
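# The composition above evaluates loss = mean(sigmoid(neg - pos)):
# when neg == pos the per-pair loss is sigmoid(0) = 0.5, and it decays
# towards 0 as pos exceeds neg by a growing margin.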
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
softmax loss
"""
import sys
import paddle.fluid as fluid
sys.path.append("../")
import nets.paddle_layers as layers
class SoftmaxCrossEntropyLoss(object):
"""
Softmax with Cross Entropy Loss Calculate
"""
def __init__(self, conf_dict):
"""
initialize
"""
pass
def compute(self, input, label):
"""
compute loss
"""
reduce_mean = layers.ReduceMeanLayer()
cost = fluid.layers.cross_entropy(input=input, label=label)
avg_cost = reduce_mean.ops(cost)
return avg_cost
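# Note: fluid.layers.cross_entropy expects `input` to already be a probability
# distribution (e.g. the softmax output of the pointwise networks above) and
# `label` to hold int64 class ids; the mean over the batch is returned.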
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
lstm class
"""
import paddle_layers as layers
from paddle.fluid.dygraph import Layer, Linear
from paddle import fluid
class LSTM(Layer):
"""
LSTM
"""
def __init__(self, conf_dict):
"""
initialize
"""
super(LSTM,self).__init__()
self.dict_size = conf_dict["dict_size"]
self.task_mode = conf_dict["task_mode"]
self.emb_dim = conf_dict["net"]["emb_dim"]
self.lstm_dim = conf_dict["net"]["lstm_dim"]
self.hidden_dim = conf_dict["net"]["hidden_dim"]
self.emb_layer = layers.EmbeddingLayer(self.dict_size, self.emb_dim, "emb").ops()
self.lstm_layer = layers.DynamicLSTMLayer(self.lstm_dim, "lstm").ops()
self.fc_layer = layers.FCLayer(self.hidden_dim, None, "fc").ops()
self.softmax_layer = layers.FCLayer(2, "softmax", "cos_sim").ops()
self.proj_layer = Linear(input_dim = self.hidden_dim, output_dim=self.lstm_dim*4)
self.seq_len = 5
def forward(self, left, right):
"""
Forward network
"""
# embedding layer
left_emb = self.emb_layer(left)
right_emb = self.emb_layer(right)
# Presentation context
left_proj = self.proj_layer(left_emb)
right_proj = self.proj_layer(right_emb)
left_lstm, _ = self.lstm_layer(left_proj)
right_lstm, _ = self.lstm_layer(right_proj)
left_emb = fluid.layers.reduce_max(left_lstm, dim=1)
right_emb = fluid.layers.reduce_max(right_lstm, dim=1)
left_emb = fluid.layers.reshape(
left_emb, shape=[-1, self.seq_len, self.hidden_dim])
right_emb = fluid.layers.reshape(
right_emb, shape=[-1, self.seq_len, self.hidden_dim])
left_emb = fluid.layers.reduce_sum(left_emb, dim=1)
right_emb = fluid.layers.reduce_sum(right_emb, dim=1)
left_last = fluid.layers.tanh(left_emb)
right_last = fluid.layers.tanh(right_emb)
# matching layer
if self.task_mode == "pairwise":
left_fc = self.fc_layer(left_last)
right_fc = self.fc_layer(right_last)
cos_sim_layer = layers.CosSimLayer()
pred = cos_sim_layer.ops(left_fc, right_fc)
return left_fc, pred
else:
concat_layer = layers.ConcatLayer(1)
concat = concat_layer.ops([left_last, right_last])
concat_fc = self.fc_layer(concat)
pred = self.softmax_layer(concat_fc)
return left_last, pred
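# Configuration sketch (illustrative values only) for the keys LSTM.__init__ reads:
#   conf_dict = {
#       "dict_size": 10000,
#       "task_mode": "pointwise",
#       "net": {"emb_dim": 128, "lstm_dim": 128, "hidden_dim": 128},
#   }
# proj_layer widens each embedding to lstm_dim * 4 because DynamicLSTMLayer
# (see paddle_layers.py) builds its BasicLSTMUnit with input_size = lstm_dim * 4.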
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MMDNN class
"""
import numpy as np
import paddle.fluid as fluid
import logging
from paddle.fluid.dygraph import Linear, to_variable, Layer, Pool2D, Conv2D
import paddle_layers as pd_layers
from paddle.fluid import layers
class MMDNN(Layer):
"""
MMDNN
"""
def __init__(self, config):
"""
initialize
"""
super(MMDNN, self).__init__()
self.vocab_size = int(config['dict_size'])
self.emb_size = int(config['net']['embedding_dim'])
self.lstm_dim = int(config['net']['lstm_dim'])
self.kernel_size = int(config['net']['num_filters'])
self.win_size1 = int(config['net']['window_size_left'])
self.win_size2 = int(config['net']['window_size_right'])
self.dpool_size1 = int(config['net']['dpool_size_left'])
self.dpool_size2 = int(config['net']['dpool_size_right'])
self.hidden_size = int(config['net']['hidden_size'])
self.seq_len1 = 5  # int(config['max_len_left'])
self.seq_len2 = 5  # int(config['max_len_right'])
self.task_mode = config['task_mode']
self.zero_pad = True
self.scale = False
if int(config['match_mask']) != 0:
self.match_mask = True
else:
self.match_mask = False
if self.task_mode == "pointwise":
self.n_class = int(config['n_class'])
self.out_size = self.n_class
elif self.task_mode == "pairwise":
self.out_size = 1
else:
logging.error("training mode not supported")
# layers
self.emb_layer = pd_layers.EmbeddingLayer(self.vocab_size, self.emb_size,
name="word_embedding",padding_idx=(0 if self.zero_pad else None)).ops()
self.fw_in_proj = Linear(
input_dim=self.emb_size,
output_dim=4 * self.lstm_dim,
param_attr=fluid.ParamAttr(name="fw_fc.w"),
bias_attr=False)
self.lstm_layer = pd_layers.DynamicLSTMLayer(self.lstm_dim, "lstm").ops()
self.rv_in_proj = Linear(
input_dim=self.emb_size,
output_dim=4 * self.lstm_dim,
param_attr=fluid.ParamAttr(name="rv_fc.w"),
bias_attr=False)
self.reverse_layer = pd_layers.DynamicLSTMLayer(
self.lstm_dim,
is_reverse=True).ops()
self.conv = Conv2D(
num_channels=1,
num_filters=self.kernel_size,
stride=1,
padding=(int(self.seq_len1 / 2), int(self.seq_len2 // 2)),
filter_size=(self.seq_len1, self.seq_len2),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.1)))
self.pool_layer = Pool2D(
pool_size=[
int(self.seq_len1 / self.dpool_size1),
int(self.seq_len2 / self.dpool_size2)
],
pool_stride=[
int(self.seq_len1 / self.dpool_size1),
int(self.seq_len2 / self.dpool_size2)
],
pool_type="max" )
self.fc_layer = pd_layers.FCLayer(self.hidden_size, "tanh", "fc").ops()
self.fc1_layer = pd_layers.FCLayer(self.out_size, "softmax", "fc1").ops()
def forward(self, left, right):
"""
Forward network
"""
left_emb = self.emb_layer(left)
right_emb = self.emb_layer(right)
if self.scale:
left_emb = left_emb * (self.emb_size**0.5)
right_emb = right_emb * (self.emb_size**0.5)
# bi_lstm
left_proj = self.fw_in_proj(left_emb)
right_proj = self.fw_in_proj(right_emb)
left_lstm, _ = self.lstm_layer(left_proj)
right_lstm, _ = self.lstm_layer(right_proj)
left_rv_proj = self.rv_in_proj(left_lstm)
right_rv_proj = self.rv_in_proj(right_lstm)
left_reverse,_ = self.reverse_layer(left_rv_proj)
right_reverse,_ = self.reverse_layer(right_rv_proj)
left_seq_encoder = fluid.layers.concat([left_lstm, left_reverse], axis=1)
right_seq_encoder = fluid.layers.concat([right_lstm, right_reverse], axis=1)
pad_value = fluid.layers.assign(input=np.array([0]).astype("float32"))
left_seq_encoder = fluid.layers.reshape(left_seq_encoder, shape=[left_seq_encoder.shape[0] // 5, 5, -1])
right_seq_encoder = fluid.layers.reshape(right_seq_encoder, shape=[right_seq_encoder.shape[0] // 5, 5, -1])
cross = fluid.layers.matmul(
left_seq_encoder, right_seq_encoder, transpose_y=True)
left_lens=to_variable(np.array([5]))
right_lens=to_variable(np.array([5]))
if self.match_mask:
mask1 = fluid.layers.sequence_mask(
x=left_lens, dtype='float32', maxlen=self.seq_len1 + 1)
mask2 = fluid.layers.sequence_mask(
x=right_lens, dtype='float32', maxlen=self.seq_len2 + 1)
mask1 = fluid.layers.transpose(x=mask1, perm=[1, 0])
mask = fluid.layers.matmul(x=mask1, y=mask2)
else:
mask = None
# conv_pool_relu
emb_expand = fluid.layers.unsqueeze(input=cross, axes=[1])
conv = self.conv(emb_expand)
if mask is not None:
cross_mask = fluid.layers.stack(x=[mask] * self.kernel_size, axis=1)
conv = cross_mask * conv + (1 - cross_mask) * (-2**5 + 1)
pool = self.pool_layer(conv)
conv_pool_relu = fluid.layers.relu(pool)
relu_hid1 = self.fc_layer(conv_pool_relu)
relu_hid1 = fluid.layers.tanh(relu_hid1)
pred = self.fc1_layer(relu_hid1)
pred = fluid.layers.softmax(pred)
return left_seq_encoder, pred
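# Configuration sketch (illustrative values only) for the keys MMDNN.__init__ reads:
#   config = {
#       "dict_size": 10000, "task_mode": "pointwise", "match_mask": 1, "n_class": 2,
#       "net": {"embedding_dim": 128, "lstm_dim": 128, "num_filters": 256,
#               "window_size_left": 3, "window_size_right": 3,
#               "dpool_size_left": 2, "dpool_size_right": 2, "hidden_size": 128},
#   }
# The forward pass embeds both sides, encodes them with forward and reverse LSTMs,
# builds a left-right interaction matrix via matmul, applies conv + dynamic
# pooling, and feeds the result through two FC layers to a softmax prediction.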
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
network layers
"""
import paddle.fluid as fluid
import paddle.fluid.param_attr as attr
class EmbeddingLayer(object):
"""
Embedding Layer class
"""
def __init__(self, dict_size, emb_dim, name="emb"):
"""
initialize
"""
self.dict_size = dict_size
self.emb_dim = emb_dim
self.name = name
def ops(self, input):
"""
operation
"""
emb = fluid.layers.embedding(
input=input,
size=[self.dict_size, self.emb_dim],
is_sparse=True,
param_attr=attr.ParamAttr(name=self.name))
return emb
class SequencePoolLayer(object):
"""
Sequence Pool Layer class
"""
def __init__(self, pool_type):
"""
initialize
"""
self.pool_type = pool_type
def ops(self, input):
"""
operation
"""
pool = fluid.layers.sequence_pool(input=input, pool_type=self.pool_type)
return pool
class FCLayer(object):
"""
Fully Connect Layer class
"""
def __init__(self, fc_dim, act, name="fc"):
"""
initialize
"""
self.fc_dim = fc_dim
self.act = act
self.name = name
def ops(self, input):
"""
operation
"""
fc = fluid.layers.fc(input=input,
size=self.fc_dim,
param_attr=attr.ParamAttr(name="%s.w" % self.name),
bias_attr=attr.ParamAttr(name="%s.b" % self.name),
act=self.act,
name=self.name)
return fc
class DynamicGRULayer(object):
"""
Dynamic GRU Layer class
"""
def __init__(self, gru_dim, name="dyn_gru"):
"""
initialize
"""
self.gru_dim = gru_dim
self.name = name
def ops(self, input):
"""
operation
"""
proj = fluid.layers.fc(
input=input,
size=self.gru_dim * 3,
param_attr=attr.ParamAttr(name="%s_fc.w" % self.name),
bias_attr=attr.ParamAttr(name="%s_fc.b" % self.name))
gru = fluid.layers.dynamic_gru(
input=proj,
size=self.gru_dim,
param_attr=attr.ParamAttr(name="%s.w" % self.name),
bias_attr=attr.ParamAttr(name="%s.b" % self.name))
return gru
class DynamicLSTMLayer(object):
"""
Dynamic LSTM Layer class
"""
def __init__(self, lstm_dim, name="dyn_lstm"):
"""
initialize
"""
self.lstm_dim = lstm_dim
self.name = name
def ops(self, input):
"""
operation
"""
proj = fluid.layers.fc(
input=input,
size=self.lstm_dim * 4,
param_attr=attr.ParamAttr(name="%s_fc.w" % self.name),
bias_attr=attr.ParamAttr(name="%s_fc.b" % self.name))
lstm, _ = fluid.layers.dynamic_lstm(
input=proj,
size=self.lstm_dim * 4,
param_attr=attr.ParamAttr(name="%s.w" % self.name),
bias_attr=attr.ParamAttr(name="%s.b" % self.name))
return lstm
class SequenceLastStepLayer(object):
"""
Get Last Step Sequence Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, input):
"""
operation
"""
last = fluid.layers.sequence_last_step(input)
return last
class SequenceConvPoolLayer(object):
"""
Sequence convolution and pooling Layer class
"""
def __init__(self, filter_size, num_filters, name):
"""
initialize
Args:
filter_size:Convolution kernel size
num_filters:Convolution kernel number
"""
self.filter_size = filter_size
self.num_filters = num_filters
self.name = name
def ops(self, input):
"""
operation
"""
conv = fluid.nets.sequence_conv_pool(
input=input,
filter_size=self.filter_size,
num_filters=self.num_filters,
param_attr=attr.ParamAttr(name=self.name),
act="relu")
return conv
class DataLayer(object):
"""
Data Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, name, shape, dtype, lod_level=0):
"""
operation
"""
data = fluid.layers.data(  # no change needed
name=name, shape=shape, dtype=dtype, lod_level=lod_level)
return data
class ConcatLayer(object):
"""
Concatenation Layer class
"""
def __init__(self, axis):
"""
initialize
"""
self.axis = axis
def ops(self, inputs):
"""
operation
"""
concat = fluid.layers.concat(inputs, axis=self.axis)
return concat
class ReduceMeanLayer(object):
"""
Reduce Mean Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, input):
"""
operation
"""
mean = fluid.layers.reduce_mean(input)
return mean
class CrossEntropyLayer(object):
"""
Cross Entropy Calculate Layer
"""
def __init__(self, name="cross_entropy"):
"""
initialize
"""
pass
def ops(self, input, label):
"""
operation
"""
loss = fluid.layers.cross_entropy(input=input, label=label)  # no change needed
return loss
class SoftmaxWithCrossEntropyLayer(object):
"""
Softmax with Cross Entropy Calculate Layer
"""
def __init__(self, name="softmax_with_cross_entropy"):
"""
initialize
"""
pass
def ops(self, input, label):
"""
operation
"""
loss = fluid.layers.softmax_with_cross_entropy(  # no change needed
logits=input, label=label)
return loss
class CosSimLayer(object):
"""
Cosine Similarity Calculate Layer
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, x, y):
"""
operation
"""
sim = fluid.layers.cos_sim(x, y)
return sim
class ElementwiseMaxLayer(object):
"""
Elementwise Max Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, x, y):
"""
operation
"""
max = fluid.layers.elementwise_max(x, y)
return max
class ElementwiseAddLayer(object):
"""
Elementwise Add Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, x, y):
"""
operation
"""
add = fluid.layers.elementwise_add(x, y)
return add
class ElementwiseSubLayer(object):
"""
Elementwise Sub Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, x, y):
"""
operation
"""
sub = fluid.layers.elementwise_sub(x, y)
return sub
class ConstantLayer(object):
"""
Generate A Constant Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, input, shape, dtype, value):
"""
operation
"""
constant = fluid.layers.fill_constant_batch_size_like(input, shape,
dtype, value)
return constant
class SigmoidLayer(object):
"""
Sigmoid Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, input):
"""
operation
"""
sigmoid = fluid.layers.sigmoid(input)
return sigmoid
class SoftsignLayer(object):
"""
Softsign Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, input):
"""
operation
"""
softsign = fluid.layers.softsign(input)
return softsign
# class MatmulLayer(object):
# def __init__(self, transpose_x, transpose_y):
# self.transpose_x = transpose_x
# self.transpose_y = transpose_y
# def ops(self, x, y):
# matmul = fluid.layers.matmul(x, y, self.transpose_x, self.transpose_y)
# return matmul
# class Conv2dLayer(object):
# def __init__(self, num_filters, filter_size, act, name):
# self.num_filters = num_filters
# self.filter_size = filter_size
# self.act = act
# self.name = name
# def ops(self, input):
# conv = fluid.layers.conv2d(input, self.num_filters, self.filter_size, param_attr=attr.ParamAttr(name="%s.w" % self.name), bias_attr=attr.ParamAttr(name="%s.b" % self.name), act=self.act)
# return conv
# class Pool2dLayer(object):
# def __init__(self, pool_size, pool_type):
# self.pool_size = pool_size
# self.pool_type = pool_type
# def ops(self, input):
# pool = fluid.layers.pool2d(input, self.pool_size, self.pool_type)
# return pool
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
network layers
"""
import collections
import contextlib
import inspect
import six
import sys
from functools import partial, reduce
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import layers
from paddle.fluid.framework import Variable
import paddle.fluid.param_attr as attr
import paddle.fluid.layers.utils as utils
from paddle.fluid.dygraph import Embedding, Pool2D, Linear, Conv2D, GRUUnit, Layer, to_variable
from paddle.fluid.layers.utils import map_structure, flatten, pack_sequence_as
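# Note: in this dygraph variant, *.ops() takes no input tensor and returns a
# callable layer object that the models apply later inside their own forward()
# methods, instead of the ops(input) style used by the graph-mode helpers above.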
class EmbeddingLayer(object):
"""
Embedding Layer class
"""
def __init__(self, dict_size, emb_dim, name="emb", padding_idx=None):
"""
initialize
"""
self.dict_size = dict_size
self.emb_dim = emb_dim
self.name = name
self.padding_idx = padding_idx
def ops(self):
"""
operation
"""
# name = self.name
emb = Embedding(
size=[self.dict_size, self.emb_dim],
is_sparse=True,
padding_idx=self.padding_idx,
param_attr=attr.ParamAttr(name=self.name, initializer=fluid.initializer.Xavier()))
return emb
class FCLayer(object):
"""
Fully Connect Layer class
"""
def __init__(self, fc_dim, act, name="fc"):
"""
initialize
"""
self.fc_dim = fc_dim
self.act = act
self.name = name
def ops(self):
"""
operation
"""
fc = FC(size=self.fc_dim,
param_attr=attr.ParamAttr(name="%s.w" % self.name),
bias_attr=attr.ParamAttr(name="%s.b" % self.name),
act=self.act)
return fc
class DynamicGRULayer(object):
"""
Dynamic GRU Layer class
"""
def __init__(self, gru_dim, name="dyn_gru"):
"""
initialize
"""
self.gru_dim = gru_dim
self.name = name
def ops(self):
"""
operation
"""
gru = DynamicGRU(
size=self.gru_dim,
param_attr=attr.ParamAttr(name="%s.w" % self.name),
bias_attr=attr.ParamAttr(name="%s.b" % self.name))
return gru
class DynamicLSTMLayer(object):
"""
Dynamic LSTM Layer class
"""
def __init__(self, lstm_dim, name="dyn_lstm", is_reverse=False):
"""
initialize
"""
self.lstm_dim = lstm_dim
self.name = name
self.is_reverse = is_reverse
def ops(self):
"""
operation
"""
lstm_cell = BasicLSTMUnit(hidden_size=self.lstm_dim, input_size=self.lstm_dim*4)
lstm = RNN(cell=lstm_cell, time_major=True, is_reverse=self.is_reverse)
return lstm
class DataLayer(object):
"""
Data Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, name, shape, dtype, lod_level=0):
"""
operation
"""
data = fluid.layers.data(
name=name, shape=shape, dtype=dtype, lod_level=lod_level)
return data
class ConcatLayer(object):
"""
Concatenation Layer class
"""
def __init__(self, axis):
"""
initialize
"""
self.axis = axis
def ops(self, inputs):
"""
operation
"""
concat = fluid.layers.concat(inputs, axis=self.axis)
return concat
class ReduceMeanLayer(object):
"""
Reduce Mean Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, input):
"""
operation
"""
mean = fluid.layers.reduce_mean(input)
return mean
class CrossEntropyLayer(object):
"""
Cross Entropy Calculate Layer
"""
def __init__(self, name="cross_entropy"):
"""
initialize
"""
pass
def ops(self, input, label):
"""
operation
"""
loss = fluid.layers.cross_entropy(input=input, label=label) # no need
return loss
class SoftmaxWithCrossEntropyLayer(object):
"""
Softmax with Cross Entropy Calculate Layer
"""
def __init__(self, name="softmax_with_cross_entropy"):
"""
initialize
"""
pass
def ops(self, input, label):
"""
operation
"""
loss = fluid.layers.softmax_with_cross_entropy( # no need
logits=input, label=label)
return loss
class CosSimLayer(object):
"""
Cosine Similarity Calculate Layer
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, x, y):
"""
operation
"""
sim = fluid.layers.cos_sim(x, y)
return sim
class ElementwiseMaxLayer(object):
"""
Elementwise Max Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, x, y):
"""
operation
"""
max = fluid.layers.elementwise_max(x, y)
return max
class ElementwiseAddLayer(object):
"""
Elementwise Add Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, x, y):
"""
operation
"""
add = fluid.layers.elementwise_add(x, y)
return add
class ElementwiseSubLayer(object):
"""
Elementwise Sub Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, x, y):
"""
operation
"""
sub = fluid.layers.elementwise_sub(x, y)
return sub
class ConstantLayer(object):
"""
Generate A Constant Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, input, shape, dtype, value):
"""
operation
"""
constant = fluid.layers.fill_constant_batch_size_like(input, shape,
dtype, value)
return constant
class SigmoidLayer(object):
"""
Sigmoid Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, input):
"""
operation
"""
sigmoid = fluid.layers.sigmoid(input)
return sigmoid
class SoftsignLayer(object):
"""
Softsign Layer class
"""
def __init__(self):
"""
initialize
"""
pass
def ops(self, input):
"""
operation
"""
softsign = fluid.layers.softsign(input)
return softsign
# dygraph
class SimpleConvPool(fluid.dygraph.Layer):
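"""
Conv2D with ReLU activation, followed by a max over the last spatial
dimension and a flatten to [batch_size, -1].
"""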
def __init__(self,
num_channels,
num_filters,
filter_size,
use_cudnn=False
):
super(SimpleConvPool, self).__init__()
self._conv2d = Conv2D(num_channels = num_channels,
num_filters=num_filters,
filter_size=filter_size,
padding=[1, 1],
use_cudnn=use_cudnn,
act='relu')
def forward(self, inputs):
x = self._conv2d(inputs)
x = fluid.layers.reduce_max(x, dim=-1)
x = fluid.layers.reshape(x, shape=[x.shape[0], -1])
return x
class FC(Layer):
"""
This interface is used to construct a callable object of the ``FC`` class.
For more details, refer to code examples.
It creates a fully connected layer in the network. It can take
one or multiple ``Tensor`` as its inputs. It creates a Variable called weights for each input tensor,
which represents a fully connected weight matrix from each input unit to
each output unit. The fully connected layer multiplies each input tensor
with its corresponding weight to produce an output Tensor with shape [N, `size`],
where N is batch size. If multiple input tensors are given, the results of
multiple output tensors with shape [N, `size`] will be summed up. If ``bias_attr``
is not None, a bias variable will be created and added to the output.
Finally, if ``act`` is not None, it will be applied to the output as well.
When the input is single ``Tensor`` :
.. math::
Out = Act({XW + b})
When the input are multiple ``Tensor`` :
.. math::
Out = Act({\sum_{i=0}^{N-1}X_iW_i + b})
In the above equation:
* :math:`N`: Number of the input. N equals to len(input) if input is list of ``Tensor`` .
* :math:`X_i`: The i-th input ``Tensor`` .
* :math:`W_i`: The i-th weights matrix corresponding i-th input tensor.
* :math:`b`: The bias parameter created by this layer (if needed).
* :math:`Act`: The activation function.
* :math:`Out`: The output ``Tensor`` .
See below for an example.
.. code-block:: text
Given:
data_1.data = [[[0.1, 0.2]]]
data_1.shape = (1, 1, 2) # 1 is batch_size
data_2.data = [[[0.1, 0.2, 0.3]]]
data_2.shape = (1, 1, 3) # 1 is batch_size
fc = FC("fc", 2, num_flatten_dims=2)
out = fc(input=[data_1, data_2])
Then:
out.data = [[[0.182996 -0.474117]]]
out.shape = (1, 1, 2)
Parameters:
size(int): The number of output units in this layer.
num_flatten_dims (int, optional): The fc layer can accept an input tensor with more than
two dimensions. If this happens, the multi-dimension tensor will first be flattened
into a 2-dimensional matrix. The parameter `num_flatten_dims` determines how the input
tensor is flattened: the first `num_flatten_dims` (inclusive, index starts from 1)
dimensions will be flatten to form the first dimension of the final matrix (height of
the matrix), and the rest `rank(X) - num_flatten_dims` dimensions are flattened to
form the second dimension of the final matrix (width of the matrix). For example, suppose
`X` is a 5-dimensional tensor with a shape [2, 3, 4, 5, 6], and `num_flatten_dims` = 3.
Then, the flattened matrix will have a shape [2 x 3 x 4, 5 x 6] = [24, 30]. Default: 1
param_attr (ParamAttr or list of ParamAttr, optional): The parameter attribute for learnable
weights(Parameter) of this layer. Default: None.
bias_attr (ParamAttr or list of ParamAttr, optional): The attribute for the bias
of this layer. If it is set to False, no bias will be added to the output units.
If it is set to None, the bias is initialized zero. Default: None.
act (str, optional): Activation to be applied to the output of this layer. Default: None.
is_test(bool, optional): A flag indicating whether execution is in test phase. Default: False.
dtype(str, optional): Dtype used for weight, it can be "float32" or "float64". Default: "float32".
Attribute:
**weight** (list of Parameter): the learnable weights of this layer.
**bias** (Parameter or None): the learnable bias of this layer.
Returns:
None
Examples:
.. code-block:: python
from paddle.fluid.dygraph.base import to_variable
import paddle.fluid as fluid
from paddle.fluid.dygraph import FC
import numpy as np
data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
with fluid.dygraph.guard():
fc = FC("fc", 64, num_flatten_dims=2)
data = to_variable(data)
conv = fc(data)
"""
def __init__(self,
size,
num_flatten_dims=1,
param_attr=None,
bias_attr=None,
act=None,
is_test=False,
dtype="float32"):
super(FC, self).__init__(dtype)
self._size = size
self._num_flatten_dims = num_flatten_dims
self._dtype = dtype
self._param_attr = param_attr
self._bias_attr = bias_attr
self._act = act
self.__w = list()
def _build_once(self, input):
i = 0
for inp, param in self._helper.iter_inputs_and_params(input,
self._param_attr):
input_shape = inp.shape
param_shape = [
reduce(lambda a, b: a * b, input_shape[self._num_flatten_dims:],
1)
] + [self._size]
self.__w.append(
self.add_parameter(
'_w%d' % i,
self.create_parameter(
attr=param,
shape=param_shape,
dtype=self._dtype,
is_bias=False)))
i += 1
size = list([self._size])
self._b = self.create_parameter(
attr=self._bias_attr, shape=size, dtype=self._dtype, is_bias=True)
# TODO(songyouwei): We should remove _w property
@property
def _w(self, i=0):
return self.__w[i]
@_w.setter
def _w(self, value, i=0):
assert isinstance(self.__w[i], Variable)
self.__w[i].set_value(value)
@property
def weight(self):
if len(self.__w) > 1:
return self.__w
else:
return self.__w[0]
@weight.setter
def weight(self, value):
if len(self.__w) == 1:
self.__w[0] = value
@property
def bias(self):
return self._b
@bias.setter
def bias(self, value):
self._b = value
def forward(self, input):
mul_results = list()
i = 0
for inp, param in self._helper.iter_inputs_and_params(input,
self._param_attr):
tmp = self._helper.create_variable_for_type_inference(self._dtype)
self._helper.append_op(
type="mul",
inputs={"X": inp,
"Y": self.__w[i]},
outputs={"Out": tmp},
attrs={
"x_num_col_dims": self._num_flatten_dims,
"y_num_col_dims": 1
})
i += 1
mul_results.append(tmp)
if len(mul_results) == 1:
pre_bias = mul_results[0]
else:
pre_bias = self._helper.create_variable_for_type_inference(
self._dtype)
self._helper.append_op(
type="sum",
inputs={"X": mul_results},
outputs={"Out": pre_bias},
attrs={"use_mkldnn": False})
if self._b:
pre_activation = self._helper.create_variable_for_type_inference(
dtype=self._dtype)
self._helper.append_op(
type='elementwise_add',
inputs={'X': [pre_bias],
'Y': [self._b]},
outputs={'Out': [pre_activation]},
attrs={'axis': self._num_flatten_dims})
else:
pre_activation = pre_bias
# Currently, we don't support inplace in dygraph mode
return self._helper.append_activation(pre_activation, act=self._act)
class DynamicGRU(Layer):
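"""
GRU unrolled step by step on top of GRUUnit. forward(inputs, h_0) expects
inputs of shape [batch_size, seq_len, size * 3] (the pre-projected gate input
required by GRUUnit) and an initial hidden state h_0 of shape [batch_size, size];
it returns the stacked hidden states of shape [batch_size, seq_len, size].
With is_reverse=True the sequence is consumed right-to-left and the outputs
are returned in the original time order.
"""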
def __init__(self,
size,
param_attr=None,
bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
origin_mode=False,
init_size = None):
super(DynamicGRU, self).__init__()
self.gru_unit = GRUUnit(
size * 3,
param_attr=param_attr,
bias_attr=bias_attr,
activation=candidate_activation,
gate_activation=gate_activation,
origin_mode=origin_mode)
self.size = size
self.is_reverse = is_reverse
def forward(self, inputs, h_0):
hidden = h_0
res = []
for i in range(inputs.shape[1]):
if self.is_reverse:
i = inputs.shape[1] - 1 - i
input_ = inputs[ :, i:i+1, :]
input_ = fluid.layers.reshape(input_, [-1, input_.shape[2]], inplace=False)
hidden, reset, gate = self.gru_unit(input_, hidden)
hidden_ = fluid.layers.reshape(hidden, [-1, 1, hidden.shape[1]], inplace=False)
res.append(hidden_)
if self.is_reverse:
res = res[::-1]
res = fluid.layers.concat(res, axis=1)
return res
class RNNUnit(Layer):
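"""
Base class for single-step recurrent cells. It provides get_initial_states(),
which builds zero-initialized states from the cell's `state_shape` /
`state_dtype` properties so that the RNN wrapper below can start the unroll.
"""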
def get_initial_states(self,
batch_ref,
shape=None,
dtype=None,
init_value=0,
batch_dim_idx=0):
"""
Generate initialized states according to provided shape, data type and
value.
Parameters:
batch_ref: A (possibly nested structure of) tensor variable[s].
The first dimension of the tensor will be used as batch size to
initialize states.
shape: A (possibly nested structure of) shape[s], where a shape is
represented as a list/tuple of integers. -1 (for batch size) will
be automatically inserted if a shape does not start with it. If None,
property `state_shape` will be used. The default value is None.
dtype: A (possibly nested structure of) data type[s]. The structure
must be the same as that of `shape`, except when all tensors in states
have the same data type, in which case a single data type can be used. If None and
property `cell.state_shape` is not available, float32 will be used
as the data type. The default value is None.
init_value: A float value used to initialize states.
Returns:
Variable: tensor variable[s] packed in the same structure provided \
by shape, representing the initialized states.
"""
# TODO: use inputs and batch_size
batch_ref = flatten(batch_ref)[0]
def _is_shape_sequence(seq):
if sys.version_info < (3, ):
integer_types = (
int,
long, )
else:
integer_types = (int, )
"""For shape, list/tuple of integer is the finest-grained objection"""
if (isinstance(seq, list) or isinstance(seq, tuple)):
if reduce(lambda flag, x: isinstance(x, integer_types) and flag,
seq, True):
return False
# TODO: Add check for the illegal
if isinstance(seq, dict):
return True
return (isinstance(seq, collections.Sequence) and
not isinstance(seq, six.string_types))
class Shape(object):
def __init__(self, shape):
self.shape = shape if shape[0] == -1 else ([-1] + list(shape))
# nested structure of shapes
states_shapes = self.state_shape if shape is None else shape
is_sequence_ori = utils.is_sequence
utils.is_sequence = _is_shape_sequence
states_shapes = map_structure(lambda shape: Shape(shape), states_shapes)
utils.is_sequence = is_sequence_ori
# nested structure of dtypes
try:
states_dtypes = self.state_dtype if dtype is None else dtype
except NotImplementedError: # use fp32 as default
states_dtypes = "float32"
if len(flatten(states_dtypes)) == 1:
dtype = flatten(states_dtypes)[0]
states_dtypes = map_structure(lambda shape: dtype, states_shapes)
init_states = map_structure(
lambda shape, dtype: fluid.layers.fill_constant_batch_size_like(
input=batch_ref,
shape=shape.shape,
dtype=dtype,
value=init_value,
input_dim_idx=batch_dim_idx), states_shapes, states_dtypes)
return init_states
@property
def state_shape(self):
"""
Abstract method (property).
Used to initialize states.
A (possibly nested structure of) shape[s], where a shape is represented
as a list/tuple of integers (-1 for batch size would be automatically
inserted into a shape if the shape does not start with it).
Not necessary to be implemented if states are not initialized by
`get_initial_states` or the `shape` argument is provided when using
`get_initial_states`.
"""
raise NotImplementedError(
"Please add implementaion for `state_shape` in the used cell.")
@property
def state_dtype(self):
"""
Abstract method (property).
Used to initialize states.
A (possibly nested structure of) data type[s]. The structure must be the
same as that of `shape`, except when all tensors in states have the same
data type, in which case a single data type can be used.
Not necessary to be implemented if states are not initialized
by `get_initial_states` or the `dtype` argument is provided when using
`get_initial_states`.
"""
raise NotImplementedError(
"Please add implementaion for `state_dtype` in the used cell.")
class BasicLSTMUnit(RNNUnit):
"""
****
BasicLSTMUnit class, using basic operators to build an LSTM.
The algorithm can be described as the code below.
.. math::
i_t &= \sigma(W_{ix}x_{t} + W_{ih}h_{t-1} + b_i)
f_t &= \sigma(W_{fx}x_{t} + W_{fh}h_{t-1} + b_f + forget_bias )
o_t &= \sigma(W_{ox}x_{t} + W_{oh}h_{t-1} + b_o)
\\tilde{c_t} &= tanh(W_{cx}x_t + W_{ch}h_{t-1} + b_c)
c_t &= f_t \odot c_{t-1} + i_t \odot \\tilde{c_t}
h_t &= o_t \odot tanh(c_t)
- $W$ terms denote weight matrices (e.g. $W_{ix}$ is the matrix
of weights from the input gate to the input)
- The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vector).
- sigmoid is the logistic sigmoid function.
- $i, f, o$ and $c$ are the input gate, forget gate, output gate,
and cell activation vectors, respectively, all of which have the same size as
the cell output activation vector $h$.
- The :math:`\odot` is the element-wise product of the vectors.
- :math:`tanh` is the activation function.
- :math:`\\tilde{c_t}` is also called candidate hidden state,
which is computed based on the current input and the previous hidden state.
Args:
hidden_size (integer): The hidden size used in the Unit.
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight matrix. Note:
If it is set to None or one attribute of ParamAttr, lstm_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|None): The parameter attribute for the bias
of LSTM unit.
If it is set to None or one attribute of ParamAttr, lstm_unit will
create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized as zero. Default: None.
gate_activation (function|None): The activation function for gates (actGate).
Default: 'fluid.layers.sigmoid'
activation (function|None): The activation function for cells (actNode).
Default: 'fluid.layers.tanh'
forget_bias(float|1.0): forget bias used when computing forget gate
dtype(string): data type used in this unit
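Examples:
.. code-block:: python
# a single-step call sketch (assumed to run under fluid.dygraph.guard();
# step_input: [batch_size, input_size], pre_hidden / pre_cell: [batch_size, hidden_size])
cell = BasicLSTMUnit(hidden_size=256, input_size=128)
output, (new_hidden, new_cell) = cell(step_input, [pre_hidden, pre_cell])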
"""
def __init__(self,
hidden_size,
input_size,
param_attr=None,
bias_attr=None,
gate_activation=None,
activation=None,
forget_bias=1.0,
dtype='float32'):
super(BasicLSTMUnit, self).__init__(dtype)
self._hidden_size = hidden_size
self._param_attr = param_attr
self._bias_attr = bias_attr
self._gate_activation = gate_activation or layers.sigmoid
self._activation = activation or layers.tanh
self._forget_bias = layers.fill_constant(
[1], dtype=dtype, value=forget_bias)
self._forget_bias.stop_gradient = False
self._dtype = dtype
self._input_size = input_size
self._weight = self.create_parameter(
attr=self._param_attr,
shape=[self._input_size + self._hidden_size, 4 * self._hidden_size],
dtype=self._dtype)
self._bias = self.create_parameter(attr=self._bias_attr,
shape=[4 * self._hidden_size],
dtype=self._dtype,
is_bias=True)
def forward(self, input, state):
pre_hidden, pre_cell = state
concat_input_hidden = layers.concat([input, pre_hidden], axis=1)
gate_input = layers.matmul(x=concat_input_hidden, y=self._weight)
gate_input = layers.elementwise_add(gate_input, self._bias)
i, j, f, o = layers.split(gate_input, num_or_sections=4, dim=-1)
new_cell = layers.elementwise_add(
layers.elementwise_mul(
pre_cell,
layers.sigmoid(layers.elementwise_add(f, self._forget_bias))),
layers.elementwise_mul(layers.sigmoid(i), layers.tanh(j)))
new_hidden = layers.tanh(new_cell) * layers.sigmoid(o)
return new_hidden, [new_hidden, new_cell]
@property
def state_shape(self):
return [[self._hidden_size], [self._hidden_size]]
class RNN(Layer):
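"""
Sequence wrapper that unrolls a single-step cell (e.g. BasicLSTMUnit) over the
time dimension in dygraph mode, or falls back to fluid.layers.rnn in graph mode.
It returns (outputs, final_states), where outputs stacks the per-step cell
outputs along the time axis. A minimal usage sketch (assumes a dygraph guard
and float32 inputs of shape [batch_size, seq_len, 8]):
.. code-block:: python
cell = BasicLSTMUnit(hidden_size=8, input_size=8)
rnn = RNN(cell=cell, time_major=False)
outputs, final_states = rnn(inputs)
"""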
def __init__(self,
cell,
is_reverse=False,
time_major=False,
**kwargs):
super(RNN, self).__init__()
self.cell = cell
if not hasattr(self.cell, "call"):
self.cell.call = self.cell.forward
self.is_reverse = is_reverse
self.time_major = time_major
self.batch_index, self.time_step_index = (1, 0) if time_major else (0, 1)
def forward(self, inputs, initial_states=None, sequence_length=None, **kwargs):
if fluid.in_dygraph_mode():
class OutputArray(object):
def __init__(self, x):
self.array = [x]
def append(self, x):
self.array.append(x)
def _maybe_copy(state, new_state, step_mask):
# TODO: use where_op
new_state = fluid.layers.elementwise_mul(
new_state, step_mask,
axis=0) - fluid.layers.elementwise_mul(state,
(step_mask - 1),
axis=0)
return new_state
flat_inputs = flatten(inputs)
batch_size, time_steps = (
flat_inputs[0].shape[self.batch_index],
flat_inputs[0].shape[self.time_step_index])
if initial_states is None:
initial_states = self.cell.get_initial_states(
batch_ref=inputs, batch_dim_idx=self.batch_index)
if not self.time_major:
inputs = map_structure(
lambda x: fluid.layers.transpose(x, [1, 0] + list(
range(2, len(x.shape)))), inputs)
if sequence_length:
mask = fluid.layers.sequence_mask(
sequence_length,
maxlen=time_steps,
dtype=flatten(initial_states)[0].dtype)
mask = fluid.layers.transpose(mask, [1, 0])
if self.is_reverse:
inputs = map_structure(lambda x: fluid.layers.reverse(x, axis=[0]), inputs)
mask = fluid.layers.reverse(mask, axis=[0]) if sequence_length else None
states = initial_states
outputs = []
for i in range(time_steps):
step_inputs = map_structure(lambda x:x[i], inputs)
step_outputs, new_states = self.cell(step_inputs, states, **kwargs)
if sequence_length:
new_states = map_structure(
partial(_maybe_copy, step_mask=mask[i]), states,
new_states)
states = new_states
if i == 0:
outputs = map_structure(lambda x: OutputArray(x),
step_outputs)
else:
map_structure(lambda x, x_array: x_array.append(x),
step_outputs, outputs)
final_outputs = map_structure(
lambda x: fluid.layers.stack(x.array, axis=self.time_step_index
), outputs)
if self.is_reverse:
final_outputs = map_structure(
lambda x: fluid.layers.reverse(x, axis=self.time_step_index
), final_outputs)
final_states = new_states
else:
final_outputs, final_states = fluid.layers.rnn(
self.cell,
inputs,
initial_states=initial_states,
sequence_length=sequence_length,
time_major=self.time_major,
is_reverse=self.is_reverse,
**kwargs)
return final_outputs, final_states
from paddle.fluid.dygraph import Embedding, LayerNorm, Linear, Layer, to_variable
place = fluid.CPUPlace()
executor = fluid.Executor(place)
class EncoderCell(RNNUnit):
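"""
A stack of num_layers BasicLSTMUnit cells meant to be driven by the RNN
wrapper above: each step feeds the input through every layer (with optional
dropout between layers) and returns the top-layer output together with the
list of per-layer [hidden, cell] states.
"""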
def __init__(self, num_layers, input_size, hidden_size, dropout_prob=0.):
super(EncoderCell, self).__init__()
self.num_layers = num_layers
self.dropout_prob = dropout_prob
self.lstm_cells = list()
for i in range(self.num_layers):
self.lstm_cells.append(
self.add_sublayer(
"layer_%d" % i,
BasicLSTMUnit(input_size if i == 0 else hidden_size,
hidden_size)))
def forward(self, step_input, states):
new_states = []
for i in range(self.num_layers):
out, new_state = self.lstm_cells[i](step_input, states[i])
step_input = layers.dropout(
out, self.dropout_prob) if self.dropout_prob > 0 else out
new_states.append(new_state)
return step_input, new_states
@property
def state_shape(self):
return [cell.state_shape for cell in self.lstm_cells]
class BasicGRUUnit(Layer):
"""
****
BasicGRUUnit class, using basic operators to build GRU
The algorithm can be described as the equations below.
.. math::
u_t & = actGate(W_ux xu_{t} + W_uh h_{t-1} + b_u)
r_t & = actGate(W_rx xr_{t} + W_rh h_{t-1} + b_r)
m_t & = actNode(W_cx xm_t + W_ch dot(r_t, h_{t-1}) + b_m)
h_t & = dot(u_t, h_{t-1}) + dot((1-u_t), m_t)
Args:
hidden_size (integer): The hidden size used in the Unit.
param_attr(ParamAttr|None): The parameter attribute for the learnable
weight matrix. Note:
If it is set to None or one attribute of ParamAttr, gru_unit will
create ParamAttr as param_attr. If the Initializer of the param_attr
is not set, the parameter is initialized with Xavier. Default: None.
bias_attr (ParamAttr|None): The parameter attribute for the bias
of GRU unit.
If it is set to None or one attribute of ParamAttr, gru_unit will
create ParamAttr as bias_attr. If the Initializer of the bias_attr
is not set, the bias is initialized zero. Default: None.
gate_activation (function|None): The activation function for gates (actGate).
Default: 'fluid.layers.sigmoid'
activation (function|None): The activation function for cell (actNode).
Default: 'fluid.layers.tanh'
dtype(string): data type used in this unit
Examples:
.. code-block:: python
import paddle.fluid.layers as layers
from paddle.fluid.contrib.layers import BasicGRUUnit
input_size = 128
hidden_size = 256
input = layers.data( name = "input", shape = [-1, input_size], dtype='float32')
pre_hidden = layers.data( name = "pre_hidden", shape=[-1, hidden_size], dtype='float32')
gru_unit = BasicGRUUnit( "gru_unit", hidden_size )
new_hidden = gru_unit( input, pre_hidden )
"""
def __init__(self,
hidden_size,
input_size,
param_attr=None,
bias_attr=None,
gate_activation=None,
activation=None,
dtype='float32'):
super(BasicGRUUnit, self).__init__(dtype)
self._hiden_size = hidden_size
self._input_size = input_size
self._param_attr = param_attr
self._bias_attr = bias_attr
self._gate_activation = gate_activation or layers.sigmoid
self._activation = activation or layers.tanh
self._dtype = dtype
self._gate_weight = self.create_parameter(
attr=self._param_attr,
shape=[self._input_size + self._hidden_size, 2 * self._hidden_size],
dtype=self._dtype)
self._candidate_weight = self.create_parameter(
attr=self._param_attr,
shape=[self._input_size + self._hidden_size, self._hidden_size],
dtype=self._dtype)
self._gate_bias = self.create_parameter(
attr=self._bias_attr,
shape=[2 * self._hidden_size],
dtype=self._dtype,
is_bias=True)
self._candidate_bias = self.create_parameter(
attr=self._bias_attr,
shape=[self._hidden_size],
dtype=self._dtype,
is_bias=True)
def forward(self, input, state):
pre_hidden = state
concat_input_hidden = fluid.layers.concat([input, pre_hidden], axis=1)
gate_input = layers.matmul(x=concat_input_hidden, y=self._gate_weight)
gate_input = layers.elementwise_add(gate_input, self._gate_bias)
gate_input = self._gate_activation(gate_input)
r, u = layers.split(gate_input, num_or_sections=2, dim=1)
r_hidden = r * pre_hidden
candidate = layers.matmul(
layers.concat([input, r_hidden], 1), self._candidate_weight)
candidate = layers.elementwise_add(candidate, self._candidate_bias)
c = self._activation(candidate)
new_hidden = u * pre_hidden + (1 - u) * c
return new_hidden
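For readers following the dygraph migration, here is a minimal sketch of stepping a single BasicGRUUnit once under dygraph mode. It assumes the class definition above and a Paddle 1.x dygraph environment; the batch, input and hidden sizes are illustrative only.

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import to_variable

# Illustrative sizes; any positive values work.
batch_size, input_size, hidden_size = 4, 8, 16

with fluid.dygraph.guard(fluid.CPUPlace()):
    gru_unit = BasicGRUUnit(hidden_size=hidden_size, input_size=input_size)
    # One time step of input and a zero previous hidden state.
    step_input = to_variable(
        np.random.rand(batch_size, input_size).astype("float32"))
    pre_hidden = to_variable(
        np.zeros((batch_size, hidden_size), dtype="float32"))
    new_hidden = gru_unit(step_input, pre_hidden)
    print(new_hidden.shape)  # [4, 16]
```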
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SimNet reader
"""
import logging
import numpy as np
import io
class SimNetProcessor(object):
def __init__(self, args, vocab):
self.args = args
# load vocab
self.vocab = vocab
self.valid_label = np.array([])
self.test_label = np.array([])
def get_reader(self, mode, epoch=0):
"""
Get Reader
"""
def reader_with_pairwise():
"""
Reader with Pairwise
"""
if mode == "valid":
with io.open(self.args.valid_data_dir, "r",
encoding="utf8") as file:
for line in file:
query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(
label) == 0 or not label.isdigit() or int(
label) not in [0, 1]:
logging.warning("line does not match format in valid file")
continue
query = [
self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
title = [
self.vocab[word] for word in title.split(" ")
if word in self.vocab
]
if len(query) == 0:
query = [0]
if len(title) == 0:
title = [0]
yield [query, title]
elif mode == "test":
with io.open(self.args.test_data_dir, "r", encoding="utf8") as file:
for line in file:
query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(
label) == 0 or not label.isdigit() or int(
label) not in [0, 1]:
logging.warning("line does not match format in test file")
continue
query = [
self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
title = [
self.vocab[word] for word in title.split(" ")
if word in self.vocab
]
if len(query) == 0:
query = [0]
if len(title) == 0:
title = [0]
yield [query, title]
else:
for idx in range(epoch):
with io.open(self.args.train_data_dir, "r",
encoding="utf8") as file:
for line in file:
query, pos_title, neg_title = line.strip().split("\t")
if len(query) == 0 or len(pos_title) == 0 or len(
neg_title) == 0:
logging.warning("line does not match format in train file")
continue
query = [
self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
pos_title = [
self.vocab[word] for word in pos_title.split(" ")
if word in self.vocab
]
neg_title = [
self.vocab[word] for word in neg_title.split(" ")
if word in self.vocab
]
if len(query) == 0:
query = [0]
if len(pos_title) == 0:
pos_title = [0]
if len(neg_title) == 0:
neg_title = [0]
yield [query, pos_title, neg_title]
def reader_with_pointwise():
"""
Reader with Pointwise
"""
if mode == "valid":
with io.open(self.args.valid_data_dir, "r",
encoding="utf8") as file:
for line in file:
query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(
label) == 0 or not label.isdigit() or int(
label) not in [0, 1]:
logging.warning("line does not match format in valid file")
continue
query = [
self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
title = [
self.vocab[word] for word in title.split(" ")
if word in self.vocab
]
if len(query) == 0:
query = [0]
if len(title) == 0:
title = [0]
yield [query, title]
elif mode == "test":
with io.open(self.args.test_data_dir, "r", encoding="utf8") as file:
for line in file:
query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(
label) == 0 or not label.isdigit() or int(
label) not in [0, 1]:
logging.warning("line does not match format in test file")
continue
query = [
self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
title = [
self.vocab[word] for word in title.split(" ")
if word in self.vocab
]
if len(query) == 0:
query = [0]
if len(title) == 0:
title = [0]
yield [query, title]
else:
for idx in range(epoch):
with io.open(self.args.train_data_dir, "r",
encoding="utf8") as file:
for line in file:
query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(
label) == 0 or not label.isdigit() or int(
label) not in [0, 1]:
logging.warning("line does not match format in train file")
continue
query = [
self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
title = [
self.vocab[word] for word in title.split(" ")
if word in self.vocab
]
label = int(label)
if len(query) == 0:
query = [0]
if len(title) == 0:
title = [0]
yield [query, title, label]
if self.args.task_mode == "pairwise":
return reader_with_pairwise
else:
return reader_with_pointwise
def get_infer_reader(self):
"""
get infer reader
"""
with io.open(self.args.infer_data_dir, "r", encoding="utf8") as file:
for line in file:
query, title = line.strip().split("\t")
if len(query) == 0 or len(title) == 0:
logging.warning("line not match format in test file")
continue
query = [
self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
title = [
self.vocab[word] for word in title.split(" ")
if word in self.vocab
]
if len(query) == 0:
query = [0]
if len(title) == 0:
title = [0]
yield [query, title]
def get_infer_data(self):
"""
get infer data
"""
with io.open(self.args.infer_data_dir, "r", encoding="utf8") as file:
for line in file:
query, title = line.strip().split("\t")
if len(query) == 0 or len(title) == 0:
logging.warning("line not match format in test file")
continue
yield line.strip()
def get_valid_label(self):
"""
get valid data label
"""
if self.valid_label.size == 0:
labels = []
with io.open(self.args.valid_data_dir, "r", encoding="utf8") as f:
for line in f:
labels.append([int(line.strip().split("\t")[-1])])
self.valid_label = np.array(labels)
return self.valid_label
def get_test_label(self):
"""
get test data label
"""
if self.test_label.size == 0:
labels = []
with io.open(self.args.test_data_dir, "r", encoding="utf8") as f:
for line in f:
labels.append([int(line.strip().split("\t")[-1])])
self.test_label = np.array(labels)
return self.test_label
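The reader above yields plain Python lists, one sample at a time; batching is left to paddle.batch, as in run_classifier.py. The sketch below wires the two together. The Namespace fields and the toy vocabulary are hypothetical stand-ins for the real ArgConfig arguments and term2id.dict, and it assumes the data downloaded by download_data.sh is in place.

```python
from argparse import Namespace
import paddle

# Hypothetical arguments; the real values come from ArgConfig / run.sh.
args = Namespace(
    task_mode="pointwise",
    train_data_dir="./data/train_pointwise_data",
    valid_data_dir="./data/test_pointwise_data",
    test_data_dir="./data/test_pointwise_data",
    infer_data_dir="./data/infer_data")
# Hypothetical toy vocabulary; normally loaded with utils.load_vocab.
vocab = {"<unk>": 0}

processor = SimNetProcessor(args, vocab)
train_reader = paddle.batch(
    processor.get_reader("train", epoch=1), batch_size=4)
for batch in train_reader():
    print(batch[0])  # one pointwise sample: [query_ids, title_ids, label]
    break
```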
#!/usr/bin/env bash
export FLAGS_enable_parallel_graph=1
export FLAGS_sync_nccl_allreduce=1
export CUDA_VISIBLE_DEVICES=3
export FLAGS_fraction_of_gpu_memory_to_use=0.95
TASK_NAME='simnet'
TRAIN_DATA_PATH=./data/train_pairwise_data
VALID_DATA_PATH=./data/test_pairwise_data
TEST_DATA_PATH=./data/test_pairwise_data
INFER_DATA_PATH=./data/infer_data
VOCAB_PATH=./data/term2id.dict
CKPT_PATH=./model_files
TEST_RESULT_PATH=./test_result
INFER_RESULT_PATH=./infer_result
TASK_MODE='pairwise'
CONFIG_PATH=./config/bow_pairwise.json
INIT_CHECKPOINT=./model_files/simnet_bow_pairwise_pretrained_model/
# run_train
train() {
python run_classifier.py \
--task_name ${TASK_NAME} \
--use_cuda False \
--do_train True \
--do_valid True \
--do_test True \
--do_infer False \
--batch_size 128 \
--train_data_dir ${TRAIN_DATA_PATH} \
--valid_data_dir ${VALID_DATA_PATH} \
--test_data_dir ${TEST_DATA_PATH} \
--infer_data_dir ${INFER_DATA_PATH} \
--output_dir ${CKPT_PATH} \
--config_path ${CONFIG_PATH} \
--vocab_path ${VOCAB_PATH} \
--epoch 40 \
--save_steps 2000 \
--validation_steps 200 \
--compute_accuracy False \
--lamda 0.958 \
--task_mode ${TASK_MODE} \
--init_checkpoint ""
}
# run_evaluate
evaluate() {
python run_classifier.py \
--task_name ${TASK_NAME} \
--use_cuda false \
--do_test True \
--verbose_result True \
--batch_size 128 \
--test_data_dir ${TEST_DATA_PATH} \
--test_result_path ${TEST_RESULT_PATH} \
--config_path ${CONFIG_PATH} \
--vocab_path ${VOCAB_PATH} \
--task_mode ${TASK_MODE} \
--compute_accuracy False \
--lamda 0.958 \
--init_checkpoint ${INIT_CHECKPOINT}
}
# run_infer
infer() {
python run_classifier.py \
--task_name ${TASK_NAME} \
--use_cuda false \
--do_infer True \
--batch_size 128 \
--infer_data_dir ${INFER_DATA_PATH} \
--infer_result_path ${INFER_RESULT_PATH} \
--config_path ${CONFIG_PATH} \
--vocab_path ${VOCAB_PATH} \
--task_mode ${TASK_MODE} \
--init_checkpoint ${INIT_CHECKPOINT}
}
main() {
local cmd=${1:-help}
case "${cmd}" in
train)
train "$@";
;;
eval)
evaluate "$@";
;;
infer)
infer "$@";
;;
help)
echo "Usage: ${BASH_SOURCE} {train|eval|infer}";
return 0;
;;
*)
echo "Unsupport commend [${cmd}]";
echo "Usage: ${BASH_SOURCE} {train|eval|infer}";
return 1;
;;
esac
}
main "$@"
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SimNet Task
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import argparse
import multiprocessing
import sys
defaultencoding = 'utf-8'
if sys.getdefaultencoding() != defaultencoding:
reload(sys)
sys.setdefaultencoding(defaultencoding)
sys.path.append("..")
import paddle
import paddle.fluid as fluid
import numpy as np
import config
import utils
import reader
import nets.paddle_layers as layers
import io
import logging
from utils import ArgConfig
from model_check import check_version
from model_check import check_cuda
def train(conf_dict, args):
"""
train process
"""
# Get device
if args.use_cuda:
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
# run train
logging.info("start train process ...")
def valid_and_test(pred_list, process, mode):
"""
return auc and acc
"""
pred_list = np.vstack(pred_list)
if mode == "test":
label_list = process.get_test_label()
elif mode == "valid":
label_list = process.get_valid_label()
if args.task_mode == "pairwise":
pred_list = (pred_list + 1) / 2
pred_list = np.hstack(
(np.ones_like(pred_list) - pred_list, pred_list))
metric.reset()
metric.update(pred_list, label_list)
auc = metric.eval()
if args.compute_accuracy:
acc = utils.get_accuracy(pred_list, label_list, args.task_mode,
args.lamda)
return auc, acc
else:
return auc
with fluid.dygraph.guard(place):
# used for continuous evaluation
if args.enable_ce:
SEED = 102
fluid.default_startup_program().random_seed = SEED
fluid.default_main_program().random_seed = SEED
# loading vocabulary
vocab = utils.load_vocab(args.vocab_path)
# get vocab size
conf_dict['dict_size'] = len(vocab)
# Load network structure dynamically
net = utils.import_class("./nets",
conf_dict["net"]["module_name"],
conf_dict["net"]["class_name"])(conf_dict)
if args.init_checkpoint:
model, _ = fluid.dygraph.load_dygraph(args.init_checkpoint)
net.set_dict(model)
# Load loss function dynamically
loss = utils.import_class("./nets/losses",
conf_dict["loss"]["module_name"],
conf_dict["loss"]["class_name"])(conf_dict)
# Load Optimization method
learning_rate = conf_dict["optimizer"]["learning_rate"]
optimizer_name = conf_dict["optimizer"]["class_name"]
if optimizer_name=='SGDOptimizer':
optimizer = fluid.optimizer.SGDOptimizer(learning_rate,parameter_list=net.parameters())
elif optimizer_name=='AdamOptimizer':
beta1 = conf_dict["optimizer"]["beta1"]
beta2 = conf_dict["optimizer"]["beta2"]
epsilon = conf_dict["optimizer"]["epsilon"]
optimizer = fluid.optimizer.AdamOptimizer(
learning_rate,
beta1=beta1,
beta2=beta2,
epsilon=epsilon,
parameter_list=net.parameters())
# load auc method
metric = fluid.metrics.Auc(name="auc")
simnet_process = reader.SimNetProcessor(args, vocab)
# set global step
global_step = 0
ce_info = []
losses = []
start_time = time.time()
train_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False)
get_train_examples = simnet_process.get_reader("train",epoch=args.epoch)
train_pyreader.decorate_sample_list_generator(
paddle.batch(get_train_examples, batch_size=args.batch_size),
place)
if args.do_valid:
valid_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False)
get_valid_examples = simnet_process.get_reader("valid")
valid_pyreader.decorate_sample_list_generator(
paddle.batch(get_valid_examples, batch_size=args.batch_size),
place)
pred_list = []
if args.task_mode == "pairwise":
for left, pos_right, neg_right in train_pyreader():
left = fluid.layers.reshape(left, shape=[-1, 1])
pos_right = fluid.layers.reshape(pos_right, shape=[-1, 1])
neg_right = fluid.layers.reshape(neg_right, shape=[-1, 1])
net.train()
global_step += 1
left_feat, pos_score = net(left, pos_right)
pred = pos_score
_, neg_score = net(left, neg_right)
avg_cost = loss.compute(pos_score, neg_score)
losses.append(np.mean(avg_cost.numpy()))
avg_cost.backward()
optimizer.minimize(avg_cost)
net.clear_gradients()
if args.do_valid and global_step % args.validation_steps == 0:
pred_list = []
for left, pos_right in valid_pyreader():
left = fluid.layers.reshape(left, shape=[-1, 1])
pos_right = fluid.layers.reshape(pos_right, shape=[-1, 1])
net.eval()
left_feat, pos_score = net(left, pos_right)
pred = pos_score
pred_list += list(pred.numpy())
valid_result = valid_and_test(pred_list, simnet_process, "valid")
if args.compute_accuracy:
valid_auc, valid_acc = valid_result
logging.info(
"global_steps: %d, valid_auc: %f, valid_acc: %f, valid_loss: %f" %
(global_step, valid_auc, valid_acc, np.mean(losses)))
else:
valid_auc = valid_result
logging.info("global_steps: %d, valid_auc: %f, valid_loss: %f" %
(global_step, valid_auc, np.mean(losses)))
if global_step % args.save_steps == 0:
model_save_dir = os.path.join(args.output_dir,
conf_dict["model_path"])
model_path = os.path.join(model_save_dir, str(global_step))
if not os.path.exists(model_save_dir):
os.makedirs(model_save_dir)
fluid.dygraph.save_dygraph(net.state_dict(), model_path)
logging.info("saving infer model in %s" % model_path)
else:
for left, right, label in train_pyreader():
left = fluid.layers.reshape(left, shape=[-1, 1])
right = fluid.layers.reshape(right, shape=[-1, 1])
label = fluid.layers.reshape(label, shape=[-1, 1])
net.train()
global_step += 1
left_feat, pred = net(left, right)
avg_cost = loss.compute(pred, label)
losses.append(np.mean(avg_cost.numpy()))
avg_cost.backward()
optimizer.minimize(avg_cost)
net.clear_gradients()
if args.do_valid and global_step % args.validation_steps == 0:
pred_list = []
for left, right in valid_pyreader():
left = fluid.layers.reshape(left, shape=[-1, 1])
right = fluid.layers.reshape(right, shape=[-1, 1])
net.eval()
left_feat, pred = net(left, right)
pred_list += list(pred.numpy())
valid_result = valid_and_test(pred_list, simnet_process, "valid")
if args.compute_accuracy:
valid_auc, valid_acc = valid_result
logging.info(
"global_steps: %d, valid_auc: %f, valid_acc: %f, valid_loss: %f" %
(global_step, valid_auc, valid_acc, np.mean(losses)))
else:
valid_auc = valid_result
logging.info("global_steps: %d, valid_auc: %f, valid_loss: %f" %
(global_step, valid_auc, np.mean(losses)))
if global_step % args.save_steps == 0:
model_save_dir = os.path.join(args.output_dir,
conf_dict["model_path"])
model_path = os.path.join(model_save_dir, str(global_step))
if not os.path.exists(model_save_dir):
os.makedirs(model_save_dir)
fluid.dygraph.save_dygraph(net.state_dict(), model_path)
logging.info("saving infer model in %s" % model_path)
end_time = time.time()
ce_info.append([np.mean(losses), end_time - start_time])
# final save
logging.info("the final step is %s" % global_step)
model_save_dir = os.path.join(args.output_dir,
conf_dict["model_path"])
model_path = os.path.join(model_save_dir, str(global_step))
if not os.path.exists(model_save_dir):
os.makedirs(model_save_dir)
fluid.dygraph.save_dygraph(net.state_dict(), model_path)
logging.info("saving infer model in %s" % model_path)
# used for continuous evaluation
if args.enable_ce:
# if True:
card_num = get_cards()
ce_loss = 0
ce_time = 0
try:
ce_loss = ce_info[-1][0]
ce_time = ce_info[-1][1]
except:
logging.info("ce info err!")
print("kpis\teach_step_duration_%s_card%s\t%s" %
(args.task_name, card_num, ce_time))
print("kpis\ttrain_loss_%s_card%s\t%f" %
(args.task_name, card_num, ce_loss))
if args.do_test:
# Get Feeder and Reader
test_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False)
get_test_examples = simnet_process.get_reader("test")
test_pyreader.decorate_sample_list_generator(
paddle.batch(get_test_examples, batch_size=args.batch_size),
place)
pred_list = []
for left, pos_right in test_pyreader():
left = fluid.layers.reshape(left, shape=[-1, 1])
pos_right = fluid.layers.reshape(pos_right, shape=[-1, 1])
net.eval()
left_feat, pos_score = net(left, pos_right)
pred = pos_score
pred_list += list(pred.numpy())
test_result = valid_and_test(pred_list, simnet_process, "test")
if args.compute_accuracy:
test_auc, test_acc = test_result
logging.info("AUC of test is %f, Accuracy of test is %f" %
(test_auc, test_acc))
else:
test_auc = test_result
logging.info("AUC of test is %f" % test_auc)
def test(conf_dict, args):
"""
Evaluation Function
"""
logging.info("start test process ...")
if args.use_cuda:
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
vocab = utils.load_vocab(args.vocab_path)
simnet_process = reader.SimNetProcessor(args, vocab)
test_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False)
get_test_examples = simnet_process.get_reader("test")
test_pyreader.decorate_sample_list_generator(
paddle.batch(get_test_examples, batch_size=args.batch_size),
place)
conf_dict['dict_size'] = len(vocab)
net = utils.import_class("./nets",
conf_dict["net"]["module_name"],
conf_dict["net"]["class_name"])(conf_dict)
model, _ = fluid.dygraph.load_dygraph(args.init_checkpoint)
net.set_dict(model)
metric = fluid.metrics.Auc(name="auc")
pred_list = []
with io.open("predictions.txt", "w", encoding="utf8") as predictions_file:
if args.task_mode == "pairwise":
for left, pos_right in test_pyreader():
left = fluid.layers.reshape(left, shape=[-1, 1])
pos_right = fluid.layers.reshape(pos_right, shape=[-1, 1])
left_feat, pos_score = net(left, pos_right)
pred = pos_score
# pred_list += list(pred.numpy())
pred_list += list(map(lambda item: float(item[0]), pred.numpy()[0]))
predictions_file.write(u"\n".join(
map(lambda item: str((item[0] + 1) / 2), pred.numpy()[0])) + "\n")
else:
for left, right in test_pyreader():
left = fluid.layers.reshape(left, shape=[-1, 1])
right = fluid.layers.reshape(right, shape=[-1, 1])
left_feat, pred = net(left, right)
# pred_list += list(pred.numpy())
pred_list += list(map(lambda item: float(item[0]), pred.numpy()[0]))
predictions_file.write(u"\n".join(
map(lambda item: str(np.argmax(item)), pred.numpy()[0])) + "\n")
if args.task_mode == "pairwise":
pred_list = np.array(pred_list).reshape((-1, 1))
pred_list = (pred_list + 1) / 2
pred_list = np.hstack(
(np.ones_like(pred_list) - pred_list, pred_list))
else:
pred_list = np.array(pred_list)
labels = simnet_process.get_test_label()
metric.update(pred_list, labels)
if args.compute_accuracy:
acc = utils.get_accuracy(pred_list, labels, args.task_mode,
args.lamda)
logging.info("AUC of test is %f, Accuracy of test is %f" %
(metric.eval(), acc))
else:
logging.info("AUC of test is %f" % metric.eval())
if args.verbose_result:
utils.get_result_file(args)
logging.info("test result saved in %s" %
os.path.join(os.getcwd(), args.test_result_path))
def infer(conf_dict, args):
"""
run predict
"""
logging.info("start test process ...")
if args.use_cuda:
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()
vocab = utils.load_vocab(args.vocab_path)
simnet_process = reader.SimNetProcessor(args, vocab)
get_infer_examples = simnet_process.get_infer_reader
infer_pyreader = fluid.io.PyReader(capacity=16, return_list=True, use_double_buffer=False)
infer_pyreader.decorate_sample_list_generator(
paddle.batch(get_infer_examples, batch_size=args.batch_size),
place)
conf_dict['dict_size'] = len(vocab)
net = utils.import_class("./nets",
conf_dict["net"]["module_name"],
conf_dict["net"]["class_name"])(conf_dict)
model, _ = fluid.dygraph.load_dygraph(args.init_checkpoint)
net.set_dict(model)
preds_list = []
if args.task_mode == "pairwise":
for left, pos_right in infer_pyreader():
left = fluid.layers.reshape(left, shape=[-1, 1])
pos_right = fluid.layers.reshape(pos_right, shape=[-1, 1])
left_feat, pos_score = net(left, pos_right)
pred = pos_score
preds_list += list(
map(lambda item: str((item[0] + 1) / 2), pred.numpy()[0]))
else:
for left, right in infer_pyreader():
left = fluid.layers.reshape(left, shape=[-1, 1])
right = fluid.layers.reshape(right, shape=[-1, 1])
left_feat, pred = net(left, right)
preds_list += map(lambda item: str(np.argmax(item)), pred.numpy()[0])
with io.open(args.infer_result_path, "w", encoding="utf8") as infer_file:
for _data, _pred in zip(simnet_process.get_infer_data(), preds_list):
infer_file.write(_data + "\t" + _pred + "\n")
logging.info("infer result saved in %s" %
os.path.join(os.getcwd(), args.infer_result_path))
def get_cards():
num = 0
cards = os.environ.get('CUDA_VISIBLE_DEVICES', '')
if cards != '':
num = len(cards.split(","))
return num
if __name__ == "__main__":
args = ArgConfig()
args = args.build_conf()
utils.print_arguments(args)
check_cuda(args.use_cuda)
check_version()
utils.init_log("./log/TextSimilarityNet")
conf_dict = config.SimNetConfig(args)
if args.do_train:
train(conf_dict, args)
elif args.do_test:
test(conf_dict, args)
elif args.do_infer:
infer(conf_dict, args)
else:
raise ValueError(
"at least one of do_train, do_test and do_infer must be True")
# -*- encoding:utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
SimNet utilities.
"""
import argparse
import time
import sys
import re
import os
import six
import numpy as np
import logging
import logging.handlers
import paddle.fluid as fluid
import io
"""
******functions for file processing******
"""
def load_vocab(file_path):
"""
load the given vocabulary
"""
vocab = {}
f = io.open(file_path, "r", encoding="utf8")
for line in f:
items = line.strip("\n").split("\t")
if items[0] not in vocab:
vocab[items[0]] = int(items[1])
vocab["<unk>"] = 0
return vocab
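load_vocab expects one term<TAB>id pair per line, matching the distributed term2id.dict. A small self-contained check, using a hypothetical toy file:

```python
import io

# Write a hypothetical two-entry vocabulary in the expected "term\tid" format.
with io.open("toy_vocab.dict", "w", encoding="utf8") as f:
    f.write(u"apple\t1\nbanana\t2\n")

vocab = load_vocab("toy_vocab.dict")
print("apple=%d banana=%d unk=%d" %
      (vocab["apple"], vocab["banana"], vocab["<unk>"]))  # apple=1 banana=2 unk=0
```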
def get_result_file(args):
"""
Get Result File
Args:
args: parsed arguments providing test_data_dir (input samples) and
test_result_path (where the merged result is written)
Returns:
result_file: merge sample and predict result
"""
with io.open(args.test_data_dir, "r", encoding="utf8") as test_file:
with io.open("predictions.txt", "r", encoding="utf8") as predictions_file:
with io.open(args.test_result_path, "w", encoding="utf8") as test_result_file:
test_datas = [line.strip("\n") for line in test_file]
predictions = [line.strip("\n") for line in predictions_file]
for test_data, prediction in zip(test_datas, predictions):
test_result_file.write(test_data + "\t" + prediction + "\n")
os.remove("predictions.txt")
"""
******functions for string processing******
"""
def pattern_match(pattern, line):
"""
Check whether a string is matched
Args:
pattern: matching pattern
line : input string
Returns:
True/False
"""
if re.match(pattern, line):
return True
else:
return False
"""
******functions for parameter processing******
"""
def print_progress(task_name, percentage, style=0):
"""
Print progress bar
Args:
task_name: The name of the current task
percentage: Current progress
style: Progress bar form
"""
styles = ['#', '█']
mark = styles[style] * percentage
mark += ' ' * (100 - percentage)
status = '%d%%' % percentage if percentage < 100 else 'Finished'
sys.stdout.write('%+20s [%s] %s\r' % (task_name, mark, status))
sys.stdout.flush()
time.sleep(0.002)
def display_args(name, args):
"""
Print parameter information
Args:
name: logger instance name
args: Input parameter dictionary
"""
logger = logging.getLogger(name)
logger.info("The arguments passed by command line is :")
for k, v in sorted(vars(args).items()):
logger.info("{}:\t{}".format(k, v))
def import_class(module_path, module_name, class_name):
"""
Load class dynamically
Args:
module_path: The current path of the module
module_name: The module name
class_name: The name of class in the import module
Return:
Return the attribute value of the class object
"""
if module_path:
sys.path.append(module_path)
module = __import__(module_name)
return getattr(module, class_name)
def str2bool(v):
"""
String to Boolean
"""
# because argparse does not support to parse "true, False" as python
# boolean directly
return v.lower() in ("true", "t", "1")
class ArgumentGroup(object):
"""
Argument Class
"""
def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des)
def add_arg(self, name, type, default, help, **kwargs):
"""
Add argument
"""
type = str2bool if type == bool else type
self._group.add_argument(
"--" + name,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
class ArgConfig(object):
def __init__(self):
parser = argparse.ArgumentParser()
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("config_path", str, None, "Path to the json file for EmoTect model config.")
model_g.add_arg("init_checkpoint", str, None, "Init checkpoint to resume training from.")
model_g.add_arg("output_dir", str, None, "Directory path to save checkpoints")
model_g.add_arg("task_mode", str, None, "task mode: pairwise or pointwise")
train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 10, "Number of epoches for training.")
train_g.add_arg("save_steps", int, 200, "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 100, "The steps interval to evaluate model performance.")
log_g = ArgumentGroup(parser, "logging", "logging related")
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
log_g.add_arg("verbose_result", bool, True, "Whether to output verbose result.")
log_g.add_arg("test_result_path", str, "test_result", "Directory path to test result.")
log_g.add_arg("infer_result_path", str, "infer_result", "Directory path to infer result.")
data_g = ArgumentGroup(parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("train_data_dir", str, None, "Directory path to training data.")
data_g.add_arg("valid_data_dir", str, None, "Directory path to valid data.")
data_g.add_arg("test_data_dir", str, None, "Directory path to testing data.")
data_g.add_arg("infer_data_dir", str, None, "Directory path to infer data.")
data_g.add_arg("vocab_path", str, None, "Vocabulary path.")
data_g.add_arg("batch_size", int, 32, "Total examples' number in batch for training.")
run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, False, "If set, use GPU for training.")
run_type_g.add_arg("task_name", str, None, "The name of task to perform sentiment classification.")
run_type_g.add_arg("do_train", bool, False, "Whether to perform training.")
run_type_g.add_arg("do_valid", bool, False, "Whether to perform dev.")
run_type_g.add_arg("do_test", bool, False, "Whether to perform testing.")
run_type_g.add_arg("do_infer", bool, False, "Whether to perform inference.")
run_type_g.add_arg("compute_accuracy", bool, False, "Whether to compute accuracy.")
run_type_g.add_arg("lamda", float, 0.91, "When task_mode is pairwise, lamda is the threshold for calculating the accuracy.")
custom_g = ArgumentGroup(parser, "customize", "customized options.")
self.custom_g = custom_g
parser.add_argument('--enable_ce',action='store_true',help='If set, run the task with continuous evaluation logs.')
self.parser = parser
def add_arg(self, name, dtype, default, descrip):
self.custom_g.add_arg(name, dtype, default, descrip)
def build_conf(self):
return self.parser.parse_args()
def print_arguments(args):
"""
Print Arguments
"""
print('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def init_log(
log_path,
level=logging.INFO,
when="D",
backup=7,
format="%(levelname)s: %(asctime)s - %(filename)s:%(lineno)d * %(thread)d %(message)s",
datefmt=None):
"""
init_log - initialize log module
Args:
log_path - Log file path prefix.
Log data will go to two files: log_path.log and log_path.log.wf
Any non-exist parent directories will be created automatically
level - msg above the level will be displayed
DEBUG < INFO < WARNING < ERROR < CRITICAL
the default value is logging.INFO
when - how to split the log file by time interval
'S' : Seconds
'M' : Minutes
'H' : Hours
'D' : Days
'W' : Week day
default value: 'D'
format - format of the log
default format:
%(levelname)s: %(asctime)s: %(filename)s:%(lineno)d * %(thread)d %(message)s
INFO: 12-09 18:02:42: log.py:40 * 139814749787872 HELLO WORLD
backup - how many backup files to keep
default value: 7
Raises:
OSError: fail to create log directories
IOError: fail to open log file
"""
formatter = logging.Formatter(format, datefmt)
logger = logging.getLogger()
logger.setLevel(level)
# console Handler
consoleHandler = logging.StreamHandler()
consoleHandler.setLevel(logging.DEBUG)
logger.addHandler(consoleHandler)
dir = os.path.dirname(log_path)
if not os.path.isdir(dir):
os.makedirs(dir)
handler = logging.handlers.TimedRotatingFileHandler(
log_path + ".log", when=when, backupCount=backup)
handler.setLevel(level)
handler.setFormatter(formatter)
logger.addHandler(handler)
handler = logging.handlers.TimedRotatingFileHandler(
log_path + ".log.wf", when=when, backupCount=backup)
handler.setLevel(logging.WARNING)
handler.setFormatter(formatter)
logger.addHandler(handler)
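A minimal sketch of how init_log behaves (run_classifier.py calls it with ./log/TextSimilarityNet): INFO and above go to <prefix>.log, WARNING and above additionally go to <prefix>.log.wf, and everything is echoed to the console.

```python
import logging

init_log("./log/TextSimilarityNet")
logging.info("goes to the console and TextSimilarityNet.log")
logging.warning("additionally goes to TextSimilarityNet.log.wf")
```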
def set_level(level):
"""
Set the log level in real time
"""
logger = logging.getLogger()
logger.setLevel(level)
logging.info('log level is set to : %d' % level)
def get_level():
"""
Get the current log level
"""
logger = logging.getLogger()
return logger.level
def get_accuracy(preds, labels, mode, lamda=0.958):
"""
compute accuracy
"""
if mode == "pairwise":
preds = np.array(list(map(lambda x: 1 if x[1] >= lamda else 0, preds)))
else:
preds = np.array(list(map(lambda x: np.argmax(x), preds)))
labels = np.squeeze(labels)
return np.mean(preds == labels)
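To make the pairwise threshold concrete, here is a small worked example with made-up scores: with the default lamda of 0.958, only column-1 scores at or above 0.958 are predicted as positive.

```python
import numpy as np

# Column 1 holds the rescaled similarity score for each pair.
preds = np.array([[0.20, 0.80],
                  [0.97, 0.03],
                  [0.04, 0.96]])
labels = np.array([[1], [0], [1]])

# Predictions become [0, 0, 1]; two of the three match the labels.
print(get_accuracy(preds, labels, "pairwise", lamda=0.958))  # ~0.667
```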
def get_softmax(preds):
"""
compute softmax
"""
_exp = np.exp(preds)
return _exp / np.sum(_exp, axis=1, keepdims=True)
def get_sigmoid(preds):
"""
compute sigmoid
"""
return 1 / (1 + np.exp(-preds))
def deal_preds_of_mmdnn(conf_dict, preds):
"""
post-process MM-DNN predictions (sigmoid for pairwise, softmax for pointwise)
"""
if conf_dict['task_mode'] == 'pairwise':
return get_sigmoid(preds)
else:
return get_softmax(preds)
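A short sketch of the two post-processing paths, with made-up raw outputs; a plain dict stands in for the SimNetConfig object here. A single score per pair is squashed with a sigmoid in pairwise mode, while two-class logits are normalised with a softmax in pointwise mode.

```python
import numpy as np

pairwise_scores = np.array([[1.5], [-0.3]])
pointwise_logits = np.array([[2.0, 0.5],
                             [0.1, 1.2]])

print(deal_preds_of_mmdnn({"task_mode": "pairwise"}, pairwise_scores))
print(deal_preds_of_mmdnn({"task_mode": "pointwise"}, pointwise_logits))
```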
def init_checkpoint(exe, init_checkpoint_path, main_program):
"""
init checkpoint
"""
assert os.path.exists(
init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path
def existed_persistables(var):
if not fluid.io.is_persistable(var):
return False
return os.path.exists(os.path.join(init_checkpoint_path, var.name))
fluid.io.load_vars(
exe,
init_checkpoint_path,
main_program=main_program,
predicate=existed_persistables)
print("Load model from {}".format(init_checkpoint_path))