diff --git a/models/recall/gru4rec/README.md b/models/recall/gru4rec/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..dc44dbb4848f75607c6f3ce6b412ac5100095ba0
--- /dev/null
+++ b/models/recall/gru4rec/README.md
@@ -0,0 +1,206 @@
+# GRU4REC
+
+The directory structure of this example is as follows:
+
+```
+├── data                  # sample data and data-processing scripts
+    ├── train
+        ├── small_train.txt   # sample training data
+    ├── test
+        ├── small_test.txt    # sample test data
+    ├── convert_format.py     # data format conversion script
+    ├── download.py           # data download script
+    ├── preprocess.py         # data preprocessing script
+    ├── text2paddle.py        # converts text data into Paddle input format
+├── __init__.py
+├── README.md             # this document
+├── model.py              # model definition
+├── config.yaml           # configuration file
+├── data_prepare.sh       # one-click data preparation script
+├── rsc15_reader.py       # dataset reader
+```
+
+Note: before reading this example, it is recommended that you first go through the
+[PaddleRec quick-start tutorial](https://github.com/PaddlePaddle/PaddleRec/blob/master/README.md).
+
+
+---
+## Contents
+
+- [Model introduction](#model-introduction)
+- [Data preparation](#data-preparation)
+- [Environment](#environment)
+- [Quick start](#quick-start)
+- [Reproducing the paper](#reproducing-the-paper)
+- [Advanced usage](#advanced-usage)
+- [FAQ](#faq)
+
+## Model introduction
+The GRU4REC model is described in the paper [Session-based Recommendations with Recurrent Neural Networks](https://arxiv.org/abs/1511.06939).
+
+The paper was the first to apply an RNN (GRU) to session-based recommendation and showed a clear improvement over traditional KNN and matrix-factorization methods.
+
+Its core idea is to treat the items a user clicks within one session as a sequence and train an RNN on these sequences. At prediction time, the clicks observed so far are fed as input and the model predicts the item most likely to be clicked next.
+
+Session-based recommendation applies to a wide range of scenarios, such as sequences of product views, news clicks, or location check-ins.
+
+This configuration uses the demo dataset by default; to verify accuracy, see [Reproducing the paper](#reproducing-the-paper).
+
+Supported features:
+
+Training: single-machine CPU, single-machine single-GPU, locally simulated parameter-server training, and incremental training; for configuration see [Launching training](https://github.com/PaddlePaddle/PaddleRec/blob/master/doc/train.md).
+
+Inference: single-machine CPU and single-machine single-GPU; for configuration see [PaddleRec offline inference](https://github.com/PaddlePaddle/PaddleRec/blob/master/doc/predict.md).
+
+## Data preparation
+Data preparation in this example consists of three steps:
+- Step 1: download the raw dataset.
+```
+cd data/
+python download.py
+```
+- Step 2: preprocess the data and convert its format.
+  1. Group the raw dataset by session_id to obtain each session's date and its ordered list of clicked items.
+  2. Filter out sessions of length 1 and items clicked fewer than 5 times.
+  3. Split into training and test sets: sessions whose last click falls on the final day of the log become the test set, and all earlier sessions become the training set (see preprocess.py).
+```
+python preprocess.py
+python convert_format.py
+```
+After this step, two files are produced under data/: rsc15_train_tr_paddle.txt (training sessions) and rsc15_test_paddle.txt (test sessions), formatted as follows:
+```
+214536502 214536500 214536506 214577561
+214662742 214662742 214825110 214757390 214757407 214551617
+214716935 214774687 214832672
+214836765 214706482
+214701242 214826623
+214826835 214826715
+214838855 214838855
+214576500 214576500 214576500
+214821275 214821275 214821371 214821371 214821371 214717089 214563337 214706462 214717436 214743335 214826837 214819762
+214717867 21471786
+```
+- Step 3: build the vocabulary and arrange the data paths. This step builds a vocabulary from the training and test files, generates the corresponding Paddle input files, and places the training files under data/all_train and the test files under data/all_test.
+```
+mkdir raw_train_data && mkdir raw_test_data
+mv rsc15_train_tr_paddle.txt raw_train_data/ && mv rsc15_test_paddle.txt raw_test_data/
+mkdir all_train && mkdir all_test
+
+python text2paddle.py raw_train_data/ raw_test_data/ all_train all_test vocab.txt
+```
+
+For convenience, a one-click data preparation script is also provided:
+```
+sh data_prepare.sh
+```
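+
+Each line of the converted data is one session: a space-separated list of item ids in click order. The sketch below is illustrative only (it is not the actual `rsc15_reader.py`); it shows how such a line is assumed to be turned into the model's `src_wordseq`/`dst_wordseq` inputs, so that every position is trained to predict the next click in the session.
+
+```python
+# Hypothetical helper, for illustration; ids are the integer item ids produced by text2paddle.py.
+def session_to_sample(line):
+    ids = [int(w) for w in line.strip().split()]
+    src_wordseq = ids[:-1]  # clicks fed to the GRU
+    dst_wordseq = ids[1:]   # next-click labels, shifted by one position
+    return src_wordseq, dst_wordseq
+
+print(session_to_sample("523 18 243 4581"))
+# -> ([523, 18, 243], [18, 243, 4581])
+```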
+
+## Environment
+
+PaddlePaddle >= 1.7.2
+
+Python 2.7/3.5/3.6/3.7
+
+PaddleRec >= 0.1
+
+OS: Windows/Linux/macOS
+
+## Quick start
+
+### Single-machine training
+
+Set the device, number of epochs, etc. in config.yaml:
+```
+runner:
+- name: cpu_train_runner
+  class: train
+  device: cpu # gpu
+  epochs: 10
+  save_checkpoint_interval: 1
+  save_inference_interval: 1
+  save_checkpoint_path: "increment_gru4rec"
+  save_inference_path: "inference_gru4rec"
+  save_inference_feed_varnames: ["src_wordseq", "dst_wordseq"] # feed vars of save inference
+  save_inference_fetch_varnames: ["mean_0.tmp_0", "top_k_0.tmp_0"]
+  print_interval: 10
+  phases: [train]
+
+```
+
+### Single-machine inference
+
+Set the device and the model path to load in config.yaml:
+```
+- name: cpu_infer_runner
+  class: infer
+  init_model_path: "increment_gru4rec"
+  device: cpu # gpu
+  phases: [infer]
+```
+
+### Run
+```
+python -m paddlerec.run -m paddlerec.models.recall.gru4rec
+```
+
+### Results
+
+Training output on the sample data:
+
+```
+Running SingleStartup.
+Running SingleRunner.
+2020-09-22 03:31:18,167-INFO: [Train], epoch: 0, batch: 10, time_each_interval: 4.34s, RecallCnt: [1669.], cost: [8.366313], InsCnt: [16228.], Acc(Recall@20): [0.10284693]
+2020-09-22 03:31:21,982-INFO: [Train], epoch: 0, batch: 20, time_each_interval: 3.82s, RecallCnt: [3168.], cost: [8.170701], InsCnt: [31943.], Acc(Recall@20): [0.09917666]
+2020-09-22 03:31:25,797-INFO: [Train], epoch: 0, batch: 30, time_each_interval: 3.81s, RecallCnt: [4855.], cost: [8.017181], InsCnt: [47892.], Acc(Recall@20): [0.10137393]
+...
+epoch 0 done, use time: 6003.78719687, global metrics: cost=[4.4394927], InsCnt=23622448.0 RecallCnt=14547467.0 Acc(Recall@20)=0.6158323218660487
+2020-09-22 05:11:17,761-INFO: save epoch_id:0 model into: "inference_gru4rec/0"
+...
+epoch 9 done, use time: 6009.97707605, global metrics: cost=[4.069373], InsCnt=236237470.0 RecallCnt=162838200.0 Acc(Recall@20)=0.6892988086157644
+2020-09-22 20:17:11,358-INFO: save epoch_id:9 model into: "inference_gru4rec/9"
+PaddleRec Finish
+```
+
+Inference output on the sample data:
+```
+Running SingleInferStartup.
+Running SingleInferRunner.
+load persistables from increment_gru4rec/9
+2020-09-23 03:46:21,081-INFO: [Infer] batch: 20, time_each_interval: 3.68s, RecallCnt: [24875.], InsCnt: [35581.], Acc(Recall@20): [0.6991091]
+Infer infer of epoch 9 done, use time: 5.25408315659, global metrics: InsCnt=52551.0 RecallCnt=36720.0 Acc(Recall@20)=0.698749785922247
+...
+Infer infer of epoch 0 done, use time: 5.20699501038, global metrics: InsCnt=52551.0 RecallCnt=33664.0 Acc(Recall@20)=0.6405967536298073
+PaddleRec Finish
+```
+
+## Reproducing the paper
+
+To reproduce the paper's results on the full dataset, modify the following hyperparameters in config.yaml:
+- batch_size: set batch_size of the dataset_train dataset to 500.
+- epochs: set epochs of the runner to 10.
+- data source: set data_path of the dataset_train dataset to "{workspace}/data/all_train" and data_path of the dataset_infer dataset to "{workspace}/data/all_test".
+
+Training for 10 epochs on GPU gives the following test results:
+
+epoch | test Recall@20 | time per epoch (s)
+-- | -- | --
+1 | 0.6406 | 6003
+2 | 0.6727 | 6007
+3 | 0.6831 | 6108
+4 | 0.6885 | 6025
+5 | 0.6913 | 6019
+6 | 0.6931 | 6011
+7 | 0.6952 | 6015
+8 | 0.6968 | 6076
+9 | 0.6972 | 6076
+10 | 0.6987 | 6009
+
+After these changes, set 'workspace' in config.yaml to the directory containing config.yaml and run:
+```
+python -m paddlerec.run -m /home/your/dir/config.yaml # debug mode: point directly at the absolute path of the local config
+```
+
+## Advanced usage
+
+## FAQ
diff --git a/models/recall/gru4rec/config.yaml b/models/recall/gru4rec/config.yaml index 09c4217a67f24126bc74e3f600415b7fcd459cba..67ac5f7692fcc94e4b247fe5fb1ae2290c3018ff 100644 --- a/models/recall/gru4rec/config.yaml +++ b/models/recall/gru4rec/config.yaml @@ -16,18 +16,19 @@ workspace: "models/recall/gru4rec" dataset: - name: dataset_train - batch_size: 5 - type: QueueDataset + batch_size: 500 + type: DataLoader # QueueDataset data_path: "{workspace}/data/train" data_converter: "{workspace}/rsc15_reader.py" - name: dataset_infer - batch_size: 5 - type: QueueDataset + batch_size: 500 + type: DataLoader #QueueDataset data_path: "{workspace}/data/test" data_converter: "{workspace}/rsc15_reader.py" hyper_parameters: - vocab_size: 1000 + recall_k: 20 + vocab_size: 37483 hid_size: 100 emb_lr_x: 10.0 gru_lr_x: 1.0 @@ -40,30 +41,34 @@ hyper_parameters: strategy: async #use infer_runner mode and modify 'phase' below if infer -mode: train_runner +mode: [cpu_train_runner, cpu_infer_runner] #mode: infer_runner runner: -- name: train_runner +- name: cpu_train_runner class: train device: cpu - epochs: 3 - 
save_checkpoint_interval: 2 - save_inference_interval: 4 - save_checkpoint_path: "increment" - save_inference_path: "inference" + epochs: 10 + save_checkpoint_interval: 1 + save_inference_interval: 1 + save_checkpoint_path: "increment_gru4rec" + save_inference_path: "inference_gru4rec" + save_inference_feed_varnames: ["src_wordseq", "dst_wordseq"] # feed vars of save inference + save_inference_fetch_varnames: ["mean_0.tmp_0", "top_k_0.tmp_0"] print_interval: 10 -- name: infer_runner + phases: [train] +- name: cpu_infer_runner class: infer - init_model_path: "increment/0" + init_model_path: "increment_gru4rec" device: cpu + phases: [infer] phase: - name: train model: "{workspace}/model.py" dataset_name: dataset_train thread_num: 1 -#- name: infer -# model: "{workspace}/model.py" -# dataset_name: dataset_infer -# thread_num: 1 +- name: infer + model: "{workspace}/model.py" + dataset_name: dataset_infer + thread_num: 1 diff --git a/models/recall/gru4rec/data/convert_format.py b/models/recall/gru4rec/data/convert_format.py new file mode 100644 index 0000000000000000000000000000000000000000..9a6867e623c1a783307f7978015bde17b0692b1a --- /dev/null +++ b/models/recall/gru4rec/data/convert_format.py @@ -0,0 +1,48 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import codecs + + +def convert_format(input, output): + with codecs.open(input, "r", encoding='utf-8') as rf: + with codecs.open(output, "w", encoding='utf-8') as wf: + last_sess = -1 + sign = 1 + i = 0 + for l in rf: + i = i + 1 + if i == 1: + continue + if (i % 1000000 == 1): + print(i) + tokens = l.strip().split() + if (int(tokens[0]) != last_sess): + if (sign): + sign = 0 + wf.write(tokens[1] + " ") + else: + wf.write("\n" + tokens[1] + " ") + last_sess = int(tokens[0]) + else: + wf.write(tokens[1] + " ") + + +input = "rsc15_train_tr.txt" +output = "rsc15_train_tr_paddle.txt" +input2 = "rsc15_test.txt" +output2 = "rsc15_test_paddle.txt" +convert_format(input, output) +convert_format(input2, output2) diff --git a/models/recall/gru4rec/data/download.py b/models/recall/gru4rec/data/download.py new file mode 100644 index 0000000000000000000000000000000000000000..b0e0979ee72bd9e6c2e24746577df970c989b6d0 --- /dev/null +++ b/models/recall/gru4rec/data/download.py @@ -0,0 +1,61 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
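+# Download the raw RSC15 / YooChoose click log (yoochoose-clicks.dat) consumed by preprocess.py.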
+
+import requests
+import shutil  # needed by the fallback below when the server returns no content-length
+import sys
+import time
+import os
+
+lasttime = time.time()
+FLUSH_INTERVAL = 0.1
+
+
+def progress(str, end=False):
+    global lasttime
+    if end:
+        str += "\n"
+        lasttime = 0
+    if time.time() - lasttime >= FLUSH_INTERVAL:
+        sys.stdout.write("\r%s" % str)
+        lasttime = time.time()
+        sys.stdout.flush()
+
+
+def _download_file(url, savepath, print_progress):
+    r = requests.get(url, stream=True)
+    total_length = r.headers.get('content-length')
+
+    if total_length is None:
+        with open(savepath, 'wb') as f:
+            shutil.copyfileobj(r.raw, f)
+    else:
+        with open(savepath, 'wb') as f:
+            dl = 0
+            total_length = int(total_length)
+            starttime = time.time()
+            if print_progress:
+                print("Downloading %s" % os.path.basename(savepath))
+            for data in r.iter_content(chunk_size=4096):
+                dl += len(data)
+                f.write(data)
+                if print_progress:
+                    done = int(50 * dl / total_length)
+                    progress("[%-50s] %.2f%%" %
+                             ('=' * done, float(100 * dl) / total_length))
+        if print_progress:
+            progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
+
+
+_download_file("https://paddlerec.bj.bcebos.com/gnn%2Fyoochoose-clicks.dat",
+               "./yoochoose-clicks.dat", True)
diff --git a/models/recall/gru4rec/data/preprocess.py b/models/recall/gru4rec/data/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..233237265ec5194bd30fc61fcfbefe189d3d8162
--- /dev/null
+++ b/models/recall/gru4rec/data/preprocess.py
@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Fri Jun 25 16:20:12 2015
+
+@author: Balázs Hidasi
+"""
+
+import numpy as np
+import pandas as pd
+import datetime as dt
+import time
+
+PATH_TO_ORIGINAL_DATA = './'
+PATH_TO_PROCESSED_DATA = './'
+
+data = pd.read_csv(
+    PATH_TO_ORIGINAL_DATA + 'yoochoose-clicks.dat',
+    sep=',',
+    header=0,
+    usecols=[0, 1, 2],
+    dtype={0: np.int32,
+           1: str,
+           2: np.int64})
+data.columns = ['session_id', 'timestamp', 'item_id']
+data['Time'] = data.timestamp.apply(lambda x: time.mktime(dt.datetime.strptime(x, '%Y-%m-%dT%H:%M:%S.%fZ').timetuple()))  #This is not UTC. It does not really matter.
+del (data['timestamp']) + +session_lengths = data.groupby('session_id').size() +data = data[np.in1d(data.session_id, session_lengths[session_lengths > 1] + .index)] +item_supports = data.groupby('item_id').size() +data = data[np.in1d(data.item_id, item_supports[item_supports >= 5].index)] +session_lengths = data.groupby('session_id').size() +data = data[np.in1d(data.session_id, session_lengths[session_lengths >= 2] + .index)] + +tmax = data.Time.max() +session_max_times = data.groupby('session_id').Time.max() +session_train = session_max_times[session_max_times < tmax - 86400].index +session_test = session_max_times[session_max_times >= tmax - 86400].index +train = data[np.in1d(data.session_id, session_train)] +test = data[np.in1d(data.session_id, session_test)] +test = test[np.in1d(test.item_id, train.item_id)] +tslength = test.groupby('session_id').size() +test = test[np.in1d(test.session_id, tslength[tslength >= 2].index)] +print('Full train set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format( + len(train), train.session_id.nunique(), train.item_id.nunique())) +train.to_csv( + PATH_TO_PROCESSED_DATA + 'rsc15_train_full.txt', sep='\t', index=False) +print('Test set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format( + len(test), test.session_id.nunique(), test.item_id.nunique())) +test.to_csv(PATH_TO_PROCESSED_DATA + 'rsc15_test.txt', sep='\t', index=False) + +tmax = train.Time.max() +session_max_times = train.groupby('session_id').Time.max() +session_train = session_max_times[session_max_times < tmax - 86400].index +session_valid = session_max_times[session_max_times >= tmax - 86400].index +train_tr = train[np.in1d(train.session_id, session_train)] +valid = train[np.in1d(train.session_id, session_valid)] +valid = valid[np.in1d(valid.item_id, train_tr.item_id)] +tslength = valid.groupby('session_id').size() +valid = valid[np.in1d(valid.session_id, tslength[tslength >= 2].index)] +print('Train set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format( + len(train_tr), train_tr.session_id.nunique(), train_tr.item_id.nunique())) +train_tr.to_csv( + PATH_TO_PROCESSED_DATA + 'rsc15_train_tr.txt', sep='\t', index=False) +print('Validation set\n\tEvents: {}\n\tSessions: {}\n\tItems: {}'.format( + len(valid), valid.session_id.nunique(), valid.item_id.nunique())) +valid.to_csv( + PATH_TO_PROCESSED_DATA + 'rsc15_train_valid.txt', sep='\t', index=False) diff --git a/models/recall/gru4rec/data/text2paddle.py b/models/recall/gru4rec/data/text2paddle.py new file mode 100644 index 0000000000000000000000000000000000000000..ff952825944d68edd2f998087aea6cd9c725e9b5 --- /dev/null +++ b/models/recall/gru4rec/data/text2paddle.py @@ -0,0 +1,115 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
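+# text2paddle.py: builds a vocabulary (written to vocab.txt) from the raw session files and
+# rewrites each session as space-separated integer ids under the output train/test directories.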
+ +import sys +import six +import collections +import os +import sys +import io +if six.PY2: + reload(sys) + sys.setdefaultencoding('utf-8') + + +def word_count(input_file, word_freq=None): + """ + compute word count from corpus + """ + if word_freq is None: + word_freq = collections.defaultdict(int) + + for l in input_file: + for w in l.strip().split(): + word_freq[w] += 1 + + return word_freq + + +def build_dict(min_word_freq=0, train_dir="", test_dir=""): + """ + Build a word dictionary from the corpus, Keys of the dictionary are words, + and values are zero-based IDs of these words. + """ + word_freq = collections.defaultdict(int) + files = os.listdir(train_dir) + for fi in files: + with io.open(os.path.join(train_dir, fi), "r") as f: + word_freq = word_count(f, word_freq) + files = os.listdir(test_dir) + for fi in files: + with io.open(os.path.join(test_dir, fi), "r") as f: + word_freq = word_count(f, word_freq) + + word_freq = [x for x in six.iteritems(word_freq) if x[1] > min_word_freq] + word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) + words, _ = list(zip(*word_freq_sorted)) + word_idx = dict(list(zip(words, six.moves.range(len(words))))) + return word_idx + + +def write_paddle(word_idx, train_dir, test_dir, output_train_dir, + output_test_dir): + files = os.listdir(train_dir) + if not os.path.exists(output_train_dir): + os.mkdir(output_train_dir) + for fi in files: + with io.open(os.path.join(train_dir, fi), "r") as f: + with io.open(os.path.join(output_train_dir, fi), "w") as wf: + for l in f: + l = l.strip().split() + l = [word_idx.get(w) for w in l] + for w in l: + wf.write(str2file(str(w) + " ")) + wf.write(str2file("\n")) + + files = os.listdir(test_dir) + if not os.path.exists(output_test_dir): + os.mkdir(output_test_dir) + for fi in files: + with io.open(os.path.join(test_dir, fi), "r", encoding='utf-8') as f: + with io.open( + os.path.join(output_test_dir, fi), "w", + encoding='utf-8') as wf: + for l in f: + l = l.strip().split() + l = [word_idx.get(w) for w in l] + for w in l: + wf.write(str2file(str(w) + " ")) + wf.write(str2file("\n")) + + +def str2file(str): + if six.PY2: + return str.decode("utf-8") + else: + return str + + +def text2paddle(train_dir, test_dir, output_train_dir, output_test_dir, + output_vocab): + vocab = build_dict(0, train_dir, test_dir) + print("vocab size:", str(len(vocab))) + with io.open(output_vocab, "w", encoding='utf-8') as wf: + wf.write(str2file(str(len(vocab)) + "\n")) + write_paddle(vocab, train_dir, test_dir, output_train_dir, output_test_dir) + + +train_dir = sys.argv[1] +test_dir = sys.argv[2] +output_train_dir = sys.argv[3] +output_test_dir = sys.argv[4] +output_vocab = sys.argv[5] +text2paddle(train_dir, test_dir, output_train_dir, output_test_dir, + output_vocab) diff --git a/models/recall/gru4rec/data_prepare.sh b/models/recall/gru4rec/data_prepare.sh new file mode 100644 index 0000000000000000000000000000000000000000..f3dc2b3f7e42fd876391cb385dfa640d2ec9a8fd --- /dev/null +++ b/models/recall/gru4rec/data_prepare.sh @@ -0,0 +1,30 @@ +#! /bin/bash + +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +echo "begin to download data" +cd data && python download.py + +python preprocess.py +echo "begin to convert data (binary -> txt)" +python convert_format.py + +mkdir raw_train_data && mkdir raw_test_data +mv rsc15_train_tr_paddle.txt raw_train_data/ && mv rsc15_test_paddle.txt raw_test_data/ + +mkdir all_train && mkdir all_test +python text2paddle.py raw_train_data/ raw_test_data/ all_train all_test vocab.txt diff --git a/models/recall/gru4rec/model.py b/models/recall/gru4rec/model.py index be12ad0f0b010b33789359592afbe8a5cfe42add..d8ebd3bded85f860ba6fa287f0ea41ef55c7168a 100644 --- a/models/recall/gru4rec/model.py +++ b/models/recall/gru4rec/model.py @@ -16,6 +16,7 @@ import paddle.fluid as fluid from paddlerec.core.utils import envs from paddlerec.core.model import ModelBase +from paddlerec.core.metrics import RecallK class Model(ModelBase): @@ -81,13 +82,13 @@ class Model(ModelBase): high=self.init_high_bound), learning_rate=self.fc_lr_x)) cost = fluid.layers.cross_entropy(input=fc, label=dst_wordseq) - acc = fluid.layers.accuracy( - input=fc, label=dst_wordseq, k=self.recall_k) + acc = RecallK(input=fc, label=dst_wordseq, k=self.recall_k) + if is_infer: - self._infer_results['recall20'] = acc + self._infer_results['Recall@20'] = acc return avg_cost = fluid.layers.mean(x=cost) self._cost = avg_cost self._metrics["cost"] = avg_cost - self._metrics["acc"] = acc + self._metrics["Recall@20"] = acc
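The `RecallK` metric wired in above reports, over all predicted positions, the fraction whose true next item appears among the top k scores (k=20 here), which is what the `Acc(Recall@20)` figures in the logs measure. A minimal NumPy sketch of that computation, as an illustration only (not PaddleRec's `RecallK` implementation):

```python
import numpy as np

def recall_at_k(scores, labels, k=20):
    """scores: [n, vocab] float array of item scores; labels: [n] int array of true next-item ids."""
    topk = np.argsort(-scores, axis=1)[:, :k]     # ids of the k highest-scoring items per row
    hits = (topk == labels[:, None]).any(axis=1)  # was the true next item among them?
    return hits.mean()                            # RecallCnt / InsCnt

scores = np.random.rand(4, 100)
labels = np.array([3, 17, 42, 99])
print(recall_at_k(scores, labels, k=20))
```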