提交 93e5453c 编写于 作者: F frankwhzhang

fix gru4rec

上级 66be4d32
...@@ -12,47 +12,59 @@ ...@@ -12,47 +12,59 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
evaluate: workspace: "paddlerec.models.recall.gru4rec"
reader:
batch_size: 1
class: "{workspace}/rsc15_infer_reader.py"
test_data_path: "{workspace}/data/train"
is_return_numpy: False
dataset:
- name: dataset_train
batch_size: 5
type: QueueDataset
data_path: "{workspace}/data/train"
data_converter: "{workspace}/rsc15_reader.py"
- name: dataset_infer
batch_size: 5
type: QueueDataset
data_path: "{workspace}/data/test"
data_converter: "{workspace}/rsc15_reader.py"
train: hyper_parameters:
trainer: vocab_size: 1000
# for cluster training hid_size: 100
strategy: "async" emb_lr_x: 10.0
gru_lr_x: 1.0
fc_lr_x: 1.0
init_low_bound: -0.04
init_high_bound: 0.04
optimizer:
class: adagrad
learning_rate: 0.01
strategy: async
#use infer_runner mode and modify 'phase' below if infer
mode: train_runner
#mode: infer_runner
runner:
- name: train_runner
class: single_train
device: cpu
epochs: 3 epochs: 3
workspace: "paddlerec.models.recall.gru4rec" save_checkpoint_interval: 2
save_inference_interval: 4
save_checkpoint_path: "increment"
save_inference_path: "inference"
print_interval: 10
- name: infer_runner
class: single_infer
init_model_path: "increment/0"
device: cpu device: cpu
epochs: 3
reader: phase:
batch_size: 5 - name: train
class: "{workspace}/rsc15_reader.py" model: "{workspace}/model.py"
train_data_path: "{workspace}/data/train" dataset_name: dataset_train
thread_num: 1
model: #- name: infer
models: "{workspace}/model.py" # model: "{workspace}/model.py"
hyper_parameters: # dataset_name: dataset_infer
vocab_size: 1000 # thread_num: 1
hid_size: 100
emb_lr_x: 10.0
gru_lr_x: 1.0
fc_lr_x: 1.0
init_low_bound: -0.04
init_high_bound: 0.04
learning_rate: 0.01
optimizer: adagrad
save:
increment:
dirname: "increment"
epoch_interval: 2
save_last: True
inference:
dirname: "inference"
epoch_interval: 4
save_last: True
...@@ -22,84 +22,72 @@ class Model(ModelBase): ...@@ -22,84 +22,72 @@ class Model(ModelBase):
def __init__(self, config): def __init__(self, config):
ModelBase.__init__(self, config) ModelBase.__init__(self, config)
def all_vocab_network(self, is_infer=False): def _init_hyper_parameters(self):
""" network definition """ self.recall_k = envs.get_global_env("hyper_parameters.recall_k")
recall_k = envs.get_global_env("hyper_parameters.recall_k", None, self.vocab_size = envs.get_global_env("hyper_parameters.vocab_size")
self._namespace) self.hid_size = envs.get_global_env("hyper_parameters.hid_size")
vocab_size = envs.get_global_env("hyper_parameters.vocab_size", None, self.init_low_bound = envs.get_global_env(
self._namespace) "hyper_parameters.init_low_bound")
hid_size = envs.get_global_env("hyper_parameters.hid_size", None, self.init_high_bound = envs.get_global_env(
self._namespace) "hyper_parameters.init_high_bound")
init_low_bound = envs.get_global_env("hyper_parameters.init_low_bound", self.emb_lr_x = envs.get_global_env("hyper_parameters.emb_lr_x")
None, self._namespace) self.gru_lr_x = envs.get_global_env("hyper_parameters.gru_lr_x")
init_high_bound = envs.get_global_env( self.fc_lr_x = envs.get_global_env("hyper_parameters.fc_lr_x")
"hyper_parameters.init_high_bound", None, self._namespace)
emb_lr_x = envs.get_global_env("hyper_parameters.emb_lr_x", None, def input_data(self, is_infer=False, **kwargs):
self._namespace)
gru_lr_x = envs.get_global_env("hyper_parameters.gru_lr_x", None,
self._namespace)
fc_lr_x = envs.get_global_env("hyper_parameters.fc_lr_x", None,
self._namespace)
# Input data # Input data
src_wordseq = fluid.data( src_wordseq = fluid.data(
name="src_wordseq", shape=[None, 1], dtype="int64", lod_level=1) name="src_wordseq", shape=[None, 1], dtype="int64", lod_level=1)
dst_wordseq = fluid.data( dst_wordseq = fluid.data(
name="dst_wordseq", shape=[None, 1], dtype="int64", lod_level=1) name="dst_wordseq", shape=[None, 1], dtype="int64", lod_level=1)
if is_infer: return [src_wordseq, dst_wordseq]
self._infer_data_var = [src_wordseq, dst_wordseq]
self._infer_data_loader = fluid.io.DataLoader.from_generator( def net(self, inputs, is_infer=False):
feed_list=self._infer_data_var, src_wordseq = inputs[0]
capacity=64, dst_wordseq = inputs[1]
use_double_buffer=False,
iterable=False)
emb = fluid.embedding( emb = fluid.embedding(
input=src_wordseq, input=src_wordseq,
size=[vocab_size, hid_size], size=[self.vocab_size, self.hid_size],
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name="emb", name="emb",
initializer=fluid.initializer.Uniform( initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound), low=self.init_low_bound, high=self.init_high_bound),
learning_rate=emb_lr_x), learning_rate=self.emb_lr_x),
is_sparse=True) is_sparse=True)
fc0 = fluid.layers.fc(input=emb, fc0 = fluid.layers.fc(input=emb,
size=hid_size * 3, size=self.hid_size * 3,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform( initializer=fluid.initializer.Uniform(
low=init_low_bound, low=self.init_low_bound,
high=init_high_bound), high=self.init_high_bound),
learning_rate=gru_lr_x)) learning_rate=self.gru_lr_x))
gru_h0 = fluid.layers.dynamic_gru( gru_h0 = fluid.layers.dynamic_gru(
input=fc0, input=fc0,
size=hid_size, size=self.hid_size,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform( initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound), low=self.init_low_bound, high=self.init_high_bound),
learning_rate=gru_lr_x)) learning_rate=self.gru_lr_x))
fc = fluid.layers.fc(input=gru_h0, fc = fluid.layers.fc(input=gru_h0,
size=vocab_size, size=self.vocab_size,
act='softmax', act='softmax',
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform( initializer=fluid.initializer.Uniform(
low=init_low_bound, high=init_high_bound), low=self.init_low_bound,
learning_rate=fc_lr_x)) high=self.init_high_bound),
learning_rate=self.fc_lr_x))
cost = fluid.layers.cross_entropy(input=fc, label=dst_wordseq) cost = fluid.layers.cross_entropy(input=fc, label=dst_wordseq)
acc = fluid.layers.accuracy(input=fc, label=dst_wordseq, k=recall_k) acc = fluid.layers.accuracy(
input=fc, label=dst_wordseq, k=self.recall_k)
if is_infer: if is_infer:
self._infer_results['recall20'] = acc self._infer_results['recall20'] = acc
return return
avg_cost = fluid.layers.mean(x=cost) avg_cost = fluid.layers.mean(x=cost)
self._data_var.append(src_wordseq)
self._data_var.append(dst_wordseq)
self._cost = avg_cost self._cost = avg_cost
self._metrics["cost"] = avg_cost self._metrics["cost"] = avg_cost
self._metrics["acc"] = acc self._metrics["acc"] = acc
def train_net(self):
self.all_vocab_network()
def infer_net(self):
self.all_vocab_network(is_infer=True)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddlerec.core.reader import Reader
class EvaluateReader(Reader):
def init(self):
pass
def generate_sample(self, line):
"""
Read the data line by line and process it as a dictionary
"""
def reader():
"""
This function needs to be implemented by the user, based on data format
"""
l = line.strip().split()
l = [w for w in l]
src_seq = l[:len(l) - 1]
src_seq = [int(e) for e in src_seq]
trg_seq = l[1:]
trg_seq = [int(e) for e in trg_seq]
feature_name = ["src_wordseq", "dst_wordseq"]
yield zip(feature_name, [src_seq] + [trg_seq])
return reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册