PaddlePaddle v2 本地训练时没有任何训练日志信息输出
Created by: yuanxiangxie
核心训练代码如下：
def train_model(self):
    """Train the cosine-similarity model and save its parameters after every pass.

    Builds a shuffled, batched train reader and a batched test reader. Scores
    each sample with the cosine similarity between the content feature and the
    word feature, rescaled to [0, 1]. Trains that score against ``label`` with
    a square-error cost using ``paddle.trainer.SGD`` plus an Adam optimizer.

    - Progress is logged to stderr every 10 batches.
    - After every pass the model is evaluated on the test reader.
    - After every pass the parameters are written to
      ``<model_save_dir>/ctr_model_<pass_id>.tar.gz`` and then moved to
      ``output/model_params/``.

    NOTE(review): ``train_data_dir``, ``test_data_dir``, ``train_buf_size``,
    ``train_batch_size``, ``test_batch_size``, ``model_save_dir`` and
    ``num_passes`` are not defined here. They are presumably module-level
    settings; confirm.

    NOTE(review): ``paddle.reader.shuffle`` expects a reader *creator* (a
    callable). This code assumes ``self.data_reader(path)`` returns one whose
    samples are ordered as in ``feeding``. If nothing at all gets logged,
    verify that this reader actually yields samples.
    """
    def log(message):
        # Write and flush explicitly so progress is visible immediately, even
        # when stderr is redirected. This also works under print_function or
        # Python 3, where ``print >> sys.stderr, ...`` raises TypeError.
        sys.stderr.write(message + "\n")
        sys.stderr.flush()

    train_data_reader = paddle.batch(
        paddle.reader.shuffle(self.data_reader(train_data_dir), buf_size=train_buf_size),
        batch_size=train_batch_size)
    test_data_reader = paddle.batch(self.data_reader(test_data_dir), batch_size=test_batch_size)

    (content_feature, word_feature) = self.build_model_feature()
    label = paddle.layer.data(name="label", type=paddle.data_type.dense_vector(1))
    inference = paddle.layer.cos_sim(a=content_feature, b=word_feature, size=1)
    # Rescale the cosine similarity from [-1, 1] to [0, 1]. Build this layer
    # once and share it between the cost and the evaluator, instead of
    # creating two identical arithmetic layers.
    score = (inference + 1.0) * 0.5
    cost = paddle.layer.square_error_cost(input=score, label=label)
    parameters = paddle.parameters.create(cost)

    adam_optimizer = paddle.optimizer.Adam(
        learning_rate=1e-3,
        regularization=paddle.optimizer.L2Regularization(rate=1e-3),
        model_average=paddle.optimizer.ModelAverage(average_window=0.5, max_average_window=10000))
    trainer = paddle.trainer.SGD(
        cost=cost,
        extra_layers=paddle.evaluator.classification_error(input=score, label=label),
        parameters=parameters,
        update_equation=adam_optimizer)

    # Maps each data layer name to its index in the sample tuples yielded by
    # the readers.
    feeding = {
        "content": 0,
        "word": 1,
        "word_len": 2,
        "label": 3
    }
    output_dir = "output/model_params"

    def event_handler(event):
        """Log training progress; at the end of each pass, test and save the model."""
        if isinstance(event, paddle.event.EndIteration):
            if event.batch_id % 10 == 0:
                log("[NOTICE] Pass:{} Batch:{} Cost:{:.2f} {}".format(
                    event.pass_id, event.batch_id, event.cost, event.metrics))
        if isinstance(event, paddle.event.EndPass):
            if test_data_reader is not None:
                result = trainer.test(reader=test_data_reader, feeding=feeding)
                log("[NOTICE] Test at Pass:{} {}".format(event.pass_id, result.metrics))
            # makedirs (not mkdir) so missing parent directories are created too.
            if not os.path.isdir(model_save_dir):
                os.makedirs(model_save_dir)
            file_name = "ctr_model_{}.tar.gz".format(event.pass_id)
            saved_path = os.path.join(model_save_dir, file_name)
            with gzip.open(saved_path, "w") as out_file:
                trainer.save_parameter_to_tar(out_file)
            log("[NOTICE] move model to output dir ...")
            # Move only this pass's file, to an explicit destination path.
            # The old code moved the whole model_save_dir every pass:
            #  - the second pass nested it inside output/model_params;
            #  - the third pass raised shutil.Error ("Destination path ...
            #    already exists"), aborting training.
            if not os.path.isdir(output_dir):
                os.makedirs(output_dir)
            shutil.move(saved_path, os.path.join(output_dir, file_name))

    log("[NOTICE] embedding feature building finished ...")
    log("[NOTICE] train ctr model start ...")
    trainer.train(
        reader=train_data_reader,
        event_handler=event_handler,
        feeding=feeding,
        num_passes=num_passes)