Unverified commit 5f187850, authored by zhang wenhui, committed by GitHub

Update2.0 model (#4905)

* update api 1.8

* fix paddlerec readme

* update 2.0, test=develop
Parent 3fad507e
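This commit ports the dygraph GRU4Rec example from the fluid 1.x API to the paddle 2.0 namespace. As orientation, here is a summary of the correspondences visible in the diff below (old 1.x call on the left, the 2.0 replacement this commit actually uses on the right):

    fluid.layers.concat(xs, 1)              ->  paddle.concat(x=xs, axis=1)
    fluid.layers.elementwise_add(a, b)      ->  paddle.add(x=a, y=b)
    fluid.layers.split(t, 2, dim=-1)        ->  paddle.split(x=t, num_or_sections=2, axis=-1)
    fluid.layers.sigmoid(t)                 ->  paddle.nn.functional.sigmoid(t)
    fluid.initializer.UniformInitializer    ->  paddle.nn.initializer.Uniform
    fluid.clip.GradientClipByGlobalNorm     ->  paddle.nn.ClipGradByGlobalNorm
    AdagradOptimizer(parameter_list=...)    ->  paddle.optimizer.Adagrad(parameters=...)
    to_variable(ndarray)                    ->  paddle.to_tensor(data=ndarray)
    with fluid.dygraph.guard(place): ...    ->  paddle.disable_static(place) / paddle.enable_static()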
[156, 51, 24, 103, 195, 35, 188, 16, 224, 173, 116, 3, 226, 11, 64, 94, 6, 70, 197, 164, 220, 77, 172, 194, 227, 12, 65, 129, 39, 38, 75, 210, 215, 36, 46, 185, 76, 222, 108, 78, 120, 71, 33, 189, 135, 97, 90, 219, 105, 205, 136, 167, 106, 29, 157, 125, 217, 121, 175, 143, 200, 45, 179, 37, 86, 140, 225, 47, 20, 228, 4, 209, 177, 178, 171, 58, 48, 118, 9, 149, 55, 192, 82, 17, 43, 54, 93, 96, 159, 216, 18, 206, 223, 104, 132, 182, 60, 109, 28, 180, 44, 166, 128, 27, 163, 141, 229, 102, 150, 7, 83, 198, 41, 191, 114, 117, 122, 161, 130, 174, 176, 160, 201, 49, 112, 69, 165, 95, 133, 92, 59, 110, 151, 203, 67, 169, 21, 66, 80, 22, 23, 152, 40, 127, 111, 186, 72, 26, 190, 42, 0, 63, 53, 124, 137, 85, 126, 196, 187, 208, 98, 25, 15, 170, 193, 168, 202, 31, 146, 147, 113, 32, 204, 131, 68, 84, 213, 19, 81, 79, 162, 199, 107, 50, 2, 207, 10, 181, 144, 139, 134, 62, 155, 142, 214, 212, 61, 52, 101, 99, 158, 145, 13, 153, 56, 184, 221]
\ No newline at end of file
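(The bracketed integer list above is an added data file, apparently the aid_data/train_file_idx.txt written by split_data in preprocess.py below: it records which part-* shards were randomly assigned to the training split, so reruns reuse the same 90/10 split.)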
import os
import shutil
import sys
LOCAL_PATH = os.path.dirname(os.path.abspath(__file__))
TOOLS_PATH = os.path.join(LOCAL_PATH, "..", "..", "tools")
sys.path.append(TOOLS_PATH)
from tools import download_file_and_uncompress, download_file
if __name__ == '__main__':
    url = "https://s3-eu-west-1.amazonaws.com/kaggle-display-advertising-challenge-dataset/dac.tar.gz"
    url2 = "https://paddlerec.bj.bcebos.com/deepfm%2Ffeat_dict_10.pkl2"

    print("download and extract starting...")
    download_file_and_uncompress(url)
    if not os.path.exists("aid_data"):
        os.makedirs("aid_data")
    download_file(url2, "./aid_data/feat_dict_10.pkl2", True)
    print("download and extract finished")

    print("preprocessing...")
    os.system("python preprocess.py")
    print("preprocess done")

    shutil.rmtree("raw_data")
    print("done")

from __future__ import division
import os
import numpy
from collections import Counter
import shutil
import pickle
def get_raw_data(input_file, raw_data, ins_per_file):
    if not os.path.isdir(raw_data):
        os.mkdir(raw_data)
    fin = open(input_file, 'r')
    fout = open(os.path.join(raw_data, 'part-0'), 'w')
    for line_idx, line in enumerate(fin):
        if line_idx % ins_per_file == 0 and line_idx != 0:
            fout.close()
            cur_part_idx = int(line_idx / ins_per_file)
            fout = open(
                os.path.join(raw_data, 'part-' + str(cur_part_idx)), 'w')
        fout.write(line)
    fout.close()
    fin.close()

def split_data(raw_data, aid_data, train_data, test_data):
    split_rate_ = 0.9
    dir_train_file_idx_ = os.path.join(aid_data, 'train_file_idx.txt')
    filelist_ = [
        os.path.join(raw_data, 'part-%d' % x)
        for x in range(len(os.listdir(raw_data)))
    ]

    if not os.path.exists(dir_train_file_idx_):
        train_file_idx = list(
            numpy.random.choice(
                len(filelist_), int(len(filelist_) * split_rate_), False))
        with open(dir_train_file_idx_, 'w') as fout:
            fout.write(str(train_file_idx))
    else:
        with open(dir_train_file_idx_, 'r') as fin:
            train_file_idx = eval(fin.read())

    for idx in range(len(filelist_)):
        if idx in train_file_idx:
            shutil.move(filelist_[idx], train_data)
        else:
            shutil.move(filelist_[idx], test_data)

def get_feat_dict(input_file, aid_data, print_freq=100000, total_ins=45000000):
    freq_ = 10
    dir_feat_dict_ = os.path.join(aid_data, 'feat_dict_' + str(freq_) + '.pkl2')
    continuous_range_ = range(1, 14)
    categorical_range_ = range(14, 40)

    if not os.path.exists(dir_feat_dict_):
        # print('generate a feature dict')
        # Count the number of occurrences of discrete features
        feat_cnt = Counter()
        with open(input_file, 'r') as fin:
            for line_idx, line in enumerate(fin):
                if line_idx % print_freq == 0:
                    print(r'generating feature dict {:.2f} %'.format((
                        line_idx / total_ins) * 100))
                features = line.rstrip('\n').split('\t')
                for idx in categorical_range_:
                    if features[idx] == '':
                        continue
                    feat_cnt.update([features[idx]])

        # Only retain discrete features with high frequency
        dis_feat_set = set()
        for feat, ot in feat_cnt.items():
            if ot >= freq_:
                dis_feat_set.add(feat)

        # Create a dictionary for continuous and discrete features
        feat_dict = {}
        tc = 1
        # Continuous features
        for idx in continuous_range_:
            feat_dict[idx] = tc
            tc += 1
        for feat in dis_feat_set:
            feat_dict[feat] = tc
            tc += 1

        # Save dictionary
        with open(dir_feat_dict_, 'wb') as fout:
            pickle.dump(feat_dict, fout, protocol=2)

        print('args.num_feat ', len(feat_dict) + 1)
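For reference, a minimal sketch of how the generated dictionary can be inspected (a hypothetical snippet; it assumes the default freq_ = 10, so the file is aid_data/feat_dict_10.pkl2 under the output directory):

    import pickle

    # Load the feature dictionary written by get_feat_dict above.
    with open('aid_data/feat_dict_10.pkl2', 'rb') as f:
        feat_dict = pickle.load(f)

    # Continuous fields are keyed by their column index 1..13 and map to
    # feature ids 1..13; frequent categorical values are keyed by their raw
    # string value and map to the subsequent ids.
    assert feat_dict[1] == 1
    print('args.num_feat', len(feat_dict) + 1)  # same count the script prints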

def preprocess(input_file,
               outdir,
               ins_per_file,
               total_ins=None,
               print_freq=None):
    train_data = os.path.join(outdir, "train_data")
    test_data = os.path.join(outdir, "test_data")
    aid_data = os.path.join(outdir, "aid_data")
    raw_data = os.path.join(outdir, "raw_data")
    if not os.path.isdir(train_data):
        os.mkdir(train_data)
    if not os.path.isdir(test_data):
        os.mkdir(test_data)
    if not os.path.isdir(aid_data):
        os.mkdir(aid_data)
    if print_freq is None:
        print_freq = 10 * ins_per_file

    get_raw_data(input_file, raw_data, ins_per_file)
    split_data(raw_data, aid_data, train_data, test_data)
    get_feat_dict(input_file, aid_data, print_freq, total_ins)

    print('Done!')


if __name__ == '__main__':
    preprocess('train.txt', './', 200000, 45000000)
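A quick sanity check on the defaults: get_raw_data shards the input into one part-N file per 200,000 lines (total_ins=45000000 only drives the progress percentage printed by get_feat_dict), and split_data then moves a random 90% of the shards to train_data/ and the rest to test_data/, persisting the chosen indices in aid_data/train_file_idx.txt. For the full Criteo train.txt this yields roughly 225-230 shards, of which about 90% become training files.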
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import paddle
import numpy as np
import six
import sys

import reader
import model_check
import time
from args import *

if sys.version[0] == '2':
    reload(sys)
    sys.setdefaultencoding("utf-8")

class SimpleGRURNN(paddle.fluid.Layer):
    def __init__(self,
                 hidden_size,
                 num_steps,
                 num_layers=2,
                 init_scale=0.1,
                 dropout=None):
        super(SimpleGRURNN, self).__init__()
        self._hidden_size = hidden_size
        self._num_layers = num_layers
        self._init_scale = init_scale
        self._dropout = dropout
        self._num_steps = num_steps

        self.weight_1_arr = []
        self.weight_2_arr = []
        self.weight_3_arr = []
        self.bias_1_arr = []
        self.bias_2_arr = []
        self.mask_array = []

        for i in range(self._num_layers):
            weight_1 = self.create_parameter(
                attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Uniform(
                    low=-self._init_scale, high=self._init_scale)),
                shape=[self._hidden_size * 2, self._hidden_size * 2],
                dtype="float32",
                default_initializer=paddle.nn.initializer.Uniform(
                    low=-self._init_scale, high=self._init_scale))
            self.weight_1_arr.append(self.add_parameter('w1_%d' % i, weight_1))
            weight_2 = self.create_parameter(
                attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Uniform(
                    low=-self._init_scale, high=self._init_scale)),
                shape=[self._hidden_size, self._hidden_size],
                dtype="float32",
                default_initializer=paddle.nn.initializer.Uniform(
                    low=-self._init_scale, high=self._init_scale))
            self.weight_2_arr.append(self.add_parameter('w2_%d' % i, weight_2))
            weight_3 = self.create_parameter(
                attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Uniform(
                    low=-self._init_scale, high=self._init_scale)),
                shape=[self._hidden_size, self._hidden_size],
                dtype="float32",
                default_initializer=paddle.nn.initializer.Uniform(
                    low=-self._init_scale, high=self._init_scale))
            self.weight_3_arr.append(self.add_parameter('w3_%d' % i, weight_3))
            bias_1 = self.create_parameter(
                attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Uniform(
                    low=-self._init_scale, high=self._init_scale)),
                shape=[self._hidden_size * 2],
                dtype="float32",
                default_initializer=paddle.nn.initializer.Constant(0.0))
            self.bias_1_arr.append(self.add_parameter('b1_%d' % i, bias_1))
            bias_2 = self.create_parameter(
                attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Uniform(
                    low=-self._init_scale, high=self._init_scale)),
                shape=[self._hidden_size * 1],
                dtype="float32",
                default_initializer=paddle.nn.initializer.Constant(0.0))
            self.bias_2_arr.append(self.add_parameter('b2_%d' % i, bias_2))

    def forward(self, input_embedding, init_hidden=None):
        hidden_array = []

        for i in range(self._num_layers):
            hidden_array.append(init_hidden[i])

        res = []
        for index in range(self._num_steps):
            step_input = input_embedding[:, index, :]
            for k in range(self._num_layers):
                pre_hidden = hidden_array[k]
                weight_1 = self.weight_1_arr[k]
                weight_2 = self.weight_2_arr[k]
                weight_3 = self.weight_3_arr[k]
                bias_1 = self.bias_1_arr[k]
                bias_2 = self.bias_2_arr[k]

                nn = paddle.concat(x=[step_input, pre_hidden], axis=1)
                gate_input = paddle.matmul(x=nn, y=weight_1)
                gate_input = paddle.add(x=gate_input, y=bias_1)
                u, r = paddle.split(x=gate_input, num_or_sections=2, axis=-1)
                hidden_c = paddle.tanh(
                    paddle.add(x=paddle.matmul(
                        x=step_input, y=weight_2) + paddle.matmul(
                            x=(paddle.nn.functional.sigmoid(r) * pre_hidden),
                            y=weight_3),
                        y=bias_2))
                hidden_state = paddle.nn.functional.sigmoid(u) * pre_hidden + (
                    1.0 - paddle.nn.functional.sigmoid(u)) * hidden_c
                hidden_array[k] = hidden_state
                step_input = hidden_state

            if self._dropout is not None and self._dropout > 0.0:
                step_input = paddle.fluid.layers.dropout(
                    step_input,
                    dropout_prob=self._dropout,
                    dropout_implementation='upscale_in_train')
            res.append(step_input)
        real_res = paddle.concat(x=res, axis=1)
        real_res = paddle.fluid.layers.reshape(
            real_res, [-1, self._num_steps, self._hidden_size])
        last_hidden = paddle.concat(x=hidden_array, axis=1)
        last_hidden = paddle.fluid.layers.reshape(
            last_hidden, shape=[-1, self._num_layers, self._hidden_size])
        last_hidden = paddle.transpose(x=last_hidden, perm=[1, 0, 2])
        return real_res, last_hidden
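For reference, each timestep of SimpleGRURNN.forward above evaluates a plain GRU cell. Writing $x_t$ for step_input, $h_{t-1}$ for pre_hidden, and $\sigma$ for the sigmoid, the code computes

$$[u_t;\, r_t] = [x_t,\, h_{t-1}]\, W_1 + b_1$$
$$\tilde{h}_t = \tanh\!\big(x_t W_2 + (\sigma(r_t) \odot h_{t-1})\, W_3 + b_2\big)$$
$$h_t = \sigma(u_t) \odot h_{t-1} + \big(1 - \sigma(u_t)\big) \odot \tilde{h}_t$$

so hidden_c is the candidate state $\tilde{h}_t$, and the update gate $\sigma(u_t)$ interpolates between the previous state and the candidate; the top layer's $h_t$ becomes the step output.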

class PtbModel(paddle.fluid.Layer):
    def __init__(self,
                 name_scope,
                 hidden_size,
                 vocab_size,
                 num_layers=2,
                 num_steps=20,
                 init_scale=0.1,
                 dropout=None):
        #super(PtbModel, self).__init__(name_scope)
        super(PtbModel, self).__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.init_scale = init_scale
        self.num_layers = num_layers
        self.num_steps = num_steps
        self.dropout = dropout
        self.simple_gru_rnn = SimpleGRURNN(
            #self.full_name(),
            hidden_size,
            num_steps,
            num_layers=num_layers,
            init_scale=init_scale,
            dropout=dropout)
        self.embedding = paddle.fluid.dygraph.nn.Embedding(
            #self.full_name(),
            size=[vocab_size, hidden_size],
            dtype='float32',
            is_sparse=False,
            param_attr=paddle.ParamAttr(
                name='embedding_para',
                initializer=paddle.nn.initializer.Uniform(
                    low=-init_scale, high=init_scale)))
        self.softmax_weight = self.create_parameter(
            attr=paddle.ParamAttr(),
            shape=[self.hidden_size, self.vocab_size],
            dtype="float32",
            default_initializer=paddle.nn.initializer.Uniform(
                low=-self.init_scale, high=self.init_scale))
        self.softmax_bias = self.create_parameter(
            attr=paddle.ParamAttr(),
            shape=[self.vocab_size],
            dtype="float32",
            default_initializer=paddle.nn.initializer.Uniform(
                low=-self.init_scale, high=self.init_scale))

    def build_once(self, input, label, init_hidden):
        pass

    def forward(self, input, label, init_hidden):
        init_h = paddle.fluid.layers.reshape(
            init_hidden, shape=[self.num_layers, -1, self.hidden_size])

        x_emb = self.embedding(input)

        x_emb = paddle.fluid.layers.reshape(
            x_emb, shape=[-1, self.num_steps, self.hidden_size])
        if self.dropout is not None and self.dropout > 0.0:
            x_emb = paddle.fluid.layers.dropout(
                x_emb,
                dropout_prob=self.dropout,
                dropout_implementation='upscale_in_train')
        rnn_out, last_hidden = self.simple_gru_rnn(x_emb, init_h)

        projection = paddle.matmul(x=rnn_out, y=self.softmax_weight)
        projection = paddle.add(x=projection, y=self.softmax_bias)
        loss = paddle.nn.functional.softmax_with_cross_entropy(
            logits=projection, label=label, soft_label=False)
        pre_2d = paddle.fluid.layers.reshape(
            projection, shape=[-1, self.vocab_size])
        label_2d = paddle.fluid.layers.reshape(label, shape=[-1, 1])
        acc = paddle.metric.accuracy(input=pre_2d, label=label_2d, k=20)
        loss = paddle.fluid.layers.reshape(loss, shape=[-1, self.num_steps])
        loss = paddle.reduce_mean(loss, dim=[0])
        loss = paddle.reduce_sum(loss)

        return loss, last_hidden, acc

    def debug_emb(self):
        np.save("emb_grad", self.x_emb.gradient())

def train_ptb_lm():
    args = parse_args()

    # check if set use_gpu=True in paddlepaddle cpu version
    model_check.check_cuda(args.use_gpu)
    # check if paddlepaddle version is satisfied
    model_check.check_version()

    model_type = args.model_type

    vocab_size = 37484
    if model_type == "gru4rec":
        num_layers = 1
        batch_size = 500
        hidden_size = 100
        num_steps = 10
        init_scale = 0.1
        max_grad_norm = 5.0
        epoch_start_decay = 10
        max_epoch = 5
        dropout = 0.0
        lr_decay = 0.5
        base_learning_rate = 0.05
    else:
        print("model type not supported")
        return

    paddle.disable_static(paddle.fluid.core.CUDAPlace(0))
    if args.ce:
        print("ce mode")
        seed = 33
        np.random.seed(seed)
        paddle.static.default_startup_program().random_seed = seed
        paddle.static.default_main_program().random_seed = seed
        max_epoch = 1
    ptb_model = PtbModel(
        "ptb_model",
        hidden_size=hidden_size,
        vocab_size=vocab_size,
        num_layers=num_layers,
        num_steps=num_steps,
        init_scale=init_scale,
        dropout=dropout)

    if args.init_from_pretrain_model:
        if not os.path.exists(args.init_from_pretrain_model + '.pdparams'):
            print(args.init_from_pretrain_model)
            raise Warning("The pretrained params do not exist.")
            return
        paddle.fluid.load_dygraph(args.init_from_pretrain_model)
        print("finish initializing model from pretrained params from %s" %
              (args.init_from_pretrain_model))

    dy_param_updated = dict()
    dy_param_init = dict()
    dy_loss = None
    last_hidden = None

    data_path = args.data_path
    print("begin to load data")
    ptb_data = reader.get_ptb_data(data_path)
    print("finished loading data")
    train_data, valid_data, test_data = ptb_data

    batch_len = len(train_data) // batch_size
    total_batch_size = (batch_len - 1) // num_steps
    print("total_batch_size:", total_batch_size)
    log_interval = total_batch_size // 20

    bd = []
    lr_arr = [base_learning_rate]
    for i in range(1, max_epoch):
        bd.append(total_batch_size * i)
        new_lr = base_learning_rate * (lr_decay**
                                       max(i + 1 - epoch_start_decay, 0.0))
        lr_arr.append(new_lr)

    grad_clip = paddle.nn.ClipGradByGlobalNorm(max_grad_norm)
    sgd = paddle.optimizer.Adagrad(
        parameters=ptb_model.parameters(),
        learning_rate=base_learning_rate,
        #learning_rate=paddle.fluid.layers.piecewise_decay(
        #    boundaries=bd, values=lr_arr),
        grad_clip=grad_clip)

    print("parameters:--------------------------------")
    for para in ptb_model.parameters():
        print(para.name)
    print("parameters:--------------------------------")

    def eval(model, data):
        print("begin to eval")
        total_loss = 0.0
        iters = 0.0
        init_hidden_data = np.zeros(
            (num_layers, batch_size, hidden_size), dtype='float32')

        model.eval()
        train_data_iter = reader.get_data_iter(data, batch_size, num_steps)
        init_hidden = paddle.to_tensor(
            data=init_hidden_data, dtype=None, place=None, stop_gradient=True)
        accum_num_recall = 0.0
        for batch_id, batch in enumerate(train_data_iter):
            x_data, y_data = batch
            x_data = x_data.reshape((-1, num_steps, 1))
            y_data = y_data.reshape((-1, num_steps, 1))
            x = paddle.to_tensor(
                data=x_data, dtype=None, place=None, stop_gradient=True)
            y = paddle.to_tensor(
                data=y_data, dtype=None, place=None, stop_gradient=True)
            dy_loss, last_hidden, acc = ptb_model(x, y, init_hidden)

            out_loss = dy_loss.numpy()
            acc_ = acc.numpy()[0]
            accum_num_recall += acc_
            if batch_id % 1 == 0:
                print("batch_id:%d recall@20:%.4f" %
                      (batch_id, accum_num_recall / (batch_id + 1)))

            init_hidden = last_hidden

            total_loss += out_loss
            iters += num_steps

        print("eval finished")
        ppl = np.exp(total_loss / iters)
        print("recall@20 ", accum_num_recall / (batch_id + 1))
        if args.ce:
            print("kpis\ttest_ppl\t%0.3f" % ppl[0])

    for epoch_id in range(max_epoch):
        ptb_model.train()
        total_loss = 0.0
        iters = 0.0
        init_hidden_data = np.zeros(
            (num_layers, batch_size, hidden_size), dtype='float32')

        train_data_iter = reader.get_data_iter(train_data, batch_size,
                                               num_steps)
        init_hidden = paddle.to_tensor(
            data=init_hidden_data, dtype=None, place=None, stop_gradient=True)

        start_time = time.time()
        for batch_id, batch in enumerate(train_data_iter):
            x_data, y_data = batch
            x_data = x_data.reshape((-1, num_steps, 1))
            y_data = y_data.reshape((-1, num_steps, 1))
            x = paddle.to_tensor(
                data=x_data, dtype=None, place=None, stop_gradient=True)
            y = paddle.to_tensor(
                data=y_data, dtype=None, place=None, stop_gradient=True)
            dy_loss, last_hidden, acc = ptb_model(x, y, init_hidden)

            out_loss = dy_loss.numpy()
            acc_ = acc.numpy()[0]

            init_hidden = last_hidden.detach()
            dy_loss.backward()
            sgd.minimize(dy_loss)
            ptb_model.clear_gradients()
            total_loss += out_loss
            iters += num_steps

            if batch_id > 0 and batch_id % 100 == 1:
                ppl = np.exp(total_loss / iters)
                print(
                    "-- Epoch:[%d]; Batch:[%d]; ppl: %.5f, acc: %.5f, lr: %.5f"
                    % (epoch_id, batch_id, ppl[0], acc_,
                       sgd._global_learning_rate().numpy()))

        print("one epoch finished", epoch_id)
        print("time cost ", time.time() - start_time)
        ppl = np.exp(total_loss / iters)
        print("-- Epoch:[%d]; ppl: %.5f" % (epoch_id, ppl[0]))
        if args.ce:
            print("kpis\ttrain_ppl\t%0.3f" % ppl[0])
        save_model_dir = os.path.join(args.save_model_dir,
                                      str(epoch_id), 'params')
        paddle.fluid.save_dygraph(ptb_model.state_dict(), save_model_dir)
        print("Saved model to: %s.\n" % save_model_dir)
        eval(ptb_model, test_data)
    paddle.enable_static()

    #eval(ptb_model, test_data)


train_ptb_lm()
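A note on the reported numbers: acc is top-20 accuracy (paddle.metric.accuracy with k=20), printed as recall@20 averaged over batches, and perplexity follows from the accumulated loss. Each batch contributes a loss $L_b$ that is averaged over the batch and summed over the num_steps positions, while iters advances by num_steps per batch, so

$$\mathrm{ppl} = \exp\!\Big(\tfrac{1}{\mathtt{iters}} \sum_b L_b\Big)$$

is the exponential of the average per-position cross-entropy.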