提交 93e5453c 编写于 作者: F frankwhzhang

fix gru4rec

上级 66be4d32
......@@ -12,47 +12,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.
evaluate:
reader:
batch_size: 1
class: "{workspace}/rsc15_infer_reader.py"
test_data_path: "{workspace}/data/train"
is_return_numpy: False
workspace: "paddlerec.models.recall.gru4rec"
dataset:
- name: dataset_train
batch_size: 5
type: QueueDataset
data_path: "{workspace}/data/train"
data_converter: "{workspace}/rsc15_reader.py"
- name: dataset_infer
batch_size: 5
type: QueueDataset
data_path: "{workspace}/data/test"
data_converter: "{workspace}/rsc15_reader.py"
train:
trainer:
# for cluster training
strategy: "async"
hyper_parameters:
vocab_size: 1000
hid_size: 100
emb_lr_x: 10.0
gru_lr_x: 1.0
fc_lr_x: 1.0
init_low_bound: -0.04
init_high_bound: 0.04
optimizer:
class: adagrad
learning_rate: 0.01
strategy: async
#use infer_runner mode and modify 'phase' below if infer
mode: train_runner
#mode: infer_runner
runner:
- name: train_runner
class: single_train
device: cpu
epochs: 3
workspace: "paddlerec.models.recall.gru4rec"
save_checkpoint_interval: 2
save_inference_interval: 4
save_checkpoint_path: "increment"
save_inference_path: "inference"
print_interval: 10
- name: infer_runner
class: single_infer
init_model_path: "increment/0"
device: cpu
epochs: 3
reader:
batch_size: 5
class: "{workspace}/rsc15_reader.py"
train_data_path: "{workspace}/data/train"
model:
models: "{workspace}/model.py"
hyper_parameters:
vocab_size: 1000
hid_size: 100
emb_lr_x: 10.0
gru_lr_x: 1.0
fc_lr_x: 1.0
init_low_bound: -0.04
init_high_bound: 0.04
learning_rate: 0.01
optimizer: adagrad
save:
increment:
dirname: "increment"
epoch_interval: 2
save_last: True
inference:
dirname: "inference"
epoch_interval: 4
save_last: True
phase:
- name: train
model: "{workspace}/model.py"
dataset_name: dataset_train
thread_num: 1
#- name: infer
# model: "{workspace}/model.py"
# dataset_name: dataset_infer
# thread_num: 1
......@@ -22,84 +22,72 @@ class Model(ModelBase):
def __init__(self, config):
ModelBase.__init__(self, config)
def all_vocab_network(self, is_infer=False):
""" network definition """
recall_k = envs.get_global_env("hyper_parameters.recall_k", None,
self._namespace)
vocab_size = envs.get_global_env("hyper_parameters.vocab_size", None,
self._namespace)
hid_size = envs.get_global_env("hyper_parameters.hid_size", None,
self._namespace)
init_low_bound = envs.get_global_env("hyper_parameters.init_low_bound",
None, self._namespace)
init_high_bound = envs.get_global_env(
"hyper_parameters.init_high_bound", None, self._namespace)
emb_lr_x = envs.get_global_env("hyper_parameters.emb_lr_x", None,
self._namespace)
gru_lr_x = envs.get_global_env("hyper_parameters.gru_lr_x", None,
self._namespace)
fc_lr_x = envs.get_global_env("hyper_parameters.fc_lr_x", None,
self._namespace)
def _init_hyper_parameters(self):
self.recall_k = envs.get_global_env("hyper_parameters.recall_k")
self.vocab_size = envs.get_global_env("hyper_parameters.vocab_size")
self.hid_size = envs.get_global_env("hyper_parameters.hid_size")
self.init_low_bound = envs.get_global_env(
"hyper_parameters.init_low_bound")
self.init_high_bound = envs.get_global_env(
"hyper_parameters.init_high_bound")
self.emb_lr_x = envs.get_global_env("hyper_parameters.emb_lr_x")
self.gru_lr_x = envs.get_global_env("hyper_parameters.gru_lr_x")
self.fc_lr_x = envs.get_global_env("hyper_parameters.fc_lr_x")
def input_data(self, is_infer=False, **kwargs):
# Input data
src_wordseq = fluid.data(
name="src_wordseq", shape=[None, 1], dtype="int64", lod_level=1)
dst_wordseq = fluid.data(
name="dst_wordseq", shape=[None, 1], dtype="int64", lod_level=1)
if is_infer:
self._infer_data_var = [src_wordseq, dst_wordseq]
self._infer_data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._infer_data_var,
capacity=64,
use_double_buffer=False,
iterable=False)
return [src_wordseq, dst_wordseq]
def net(self, inputs, is_infer=False):
src_wordseq = inputs[0]
dst_wordseq = inputs[1]
emb = fluid.embedding(
input=src_wordseq,
size=[vocab_size, hid_size],
size=[self.vocab_size, self.hid_size],
param_attr=fluid.ParamAttr(
name="emb",
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=emb_lr_x),
low=self.init_low_bound, high=self.init_high_bound),
learning_rate=self.emb_lr_x),
is_sparse=True)
fc0 = fluid.layers.fc(input=emb,
size=hid_size * 3,
size=self.hid_size * 3,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound,
high=init_high_bound),
learning_rate=gru_lr_x))
low=self.init_low_bound,
high=self.init_high_bound),
learning_rate=self.gru_lr_x))
gru_h0 = fluid.layers.dynamic_gru(
input=fc0,
size=hid_size,
size=self.hid_size,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=gru_lr_x))
low=self.init_low_bound, high=self.init_high_bound),
learning_rate=self.gru_lr_x))
fc = fluid.layers.fc(input=gru_h0,
size=vocab_size,
size=self.vocab_size,
act='softmax',
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound),
learning_rate=fc_lr_x))
low=self.init_low_bound,
high=self.init_high_bound),
learning_rate=self.fc_lr_x))
cost = fluid.layers.cross_entropy(input=fc, label=dst_wordseq)
acc = fluid.layers.accuracy(input=fc, label=dst_wordseq, k=recall_k)
acc = fluid.layers.accuracy(
input=fc, label=dst_wordseq, k=self.recall_k)
if is_infer:
self._infer_results['recall20'] = acc
return
avg_cost = fluid.layers.mean(x=cost)
self._data_var.append(src_wordseq)
self._data_var.append(dst_wordseq)
self._cost = avg_cost
self._metrics["cost"] = avg_cost
self._metrics["acc"] = acc
def train_net(self):
self.all_vocab_network()
def infer_net(self):
self.all_vocab_network(is_infer=True)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddlerec.core.reader import Reader
class EvaluateReader(Reader):
def init(self):
pass
def generate_sample(self, line):
"""
Read the data line by line and process it as a dictionary
"""
def reader():
"""
This function needs to be implemented by the user, based on data format
"""
l = line.strip().split()
l = [w for w in l]
src_seq = l[:len(l) - 1]
src_seq = [int(e) for e in src_seq]
trg_seq = l[1:]
trg_seq = [int(e) for e in trg_seq]
feature_name = ["src_wordseq", "dst_wordseq"]
yield zip(feature_name, [src_seq] + [trg_seq])
return reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册