在mq2007数据集上训练LR模型,train cost没有呈现下降。
Created by: lutaojian
尝试使用paddle实现LR模型,在mq2007数据集上做pointwise模型,具体网络实现如下。训练过程中train cost没有呈现下降趋势,但test cost呈现下降,无法确定是否已经收敛。
有两个问题求大神解答: 1)下文的网络实现是否与LR模型等价; 2)训练过程中train cost不下降是否正常。
def sigmoid(input_dim, initial_std=0.01):
    """Build a single sigmoid output unit over a dense feature vector.

    The graph is: data layer ("data") -> fully connected layer of size 1
    with sigmoid activation. Paired with a binary cross-entropy cost this
    is logistic regression, i.e. sigmoid(w . x + b).

    Args:
        input_dim: width of the dense input feature vector.
        initial_std: std of the initial weight distribution (0.01 by
            default, the value previously hard-coded).

    Returns:
        The size-1 sigmoid output layer.
    """
    # Input slot; its name must match the "data" key used in `feeding`.
    # (In the collapsed paste this assignment sat behind the "# data layer"
    # comment, which left `data` undefined.)
    data = paddle.layer.data("data", paddle.data_type.dense_vector(input_dim))

    # One output unit + sigmoid == the LR decision function.
    # NOTE(review): the bias term comes from fc's default bias_attr; confirm
    # it is enabled, otherwise this is LR without an intercept.
    output = paddle.layer.fc(
        input=data,
        size=1,
        act=paddle.activation.Sigmoid(),
        param_attr=paddle.attr.Param(initial_std=initial_std, name="output"))
    return output
def lr(input_dim):
    """Build the logistic-regression training graph and return its cost layer.

    Args:
        input_dim: width of the dense input feature vector passed to
            `sigmoid`.

    Returns:
        A `multi_binary_label_cross_entropy_cost` layer comparing the sigmoid
        output against the "label" input slot.
    """
    # Label slot; its name must match the "label" key used in `feeding`.
    # (In the collapsed paste this assignment sat behind the "# label layer"
    # comment, which left `label` undefined.)
    label = paddle.layer.data("label", paddle.data_type.dense_vector(1))

    output = sigmoid(input_dim)

    # Binary cross-entropy on a single sigmoid unit == the LR loss, but only
    # for labels in {0, 1}.
    # NOTE(review): MQ2007 relevance judgments are graded (LETOR 4.0 uses
    # 0/1/2). If the pointwise reader passes them through unchanged, a label
    # of 2 makes this cost unbounded below as the prediction approaches 1.
    # Verify the label range and binarize (e.g. relevance > 0 -> 1) if needed.
    cost = paddle.layer.multi_binary_label_cross_entropy_cost(
        input=output, label=label)
    return cost
def train_lr(num_passes, batch_size=100, shuffle_buf_size=10000,
             learning_rate=2e-4, log_period=100):
    """Train the LR model on MQ2007 in pointwise format with Adam.

    After every pass the model is evaluated on the test split and the
    parameters are written to ``lr_params_<pass_id>.tar.gz``.

    Args:
        num_passes: number of passes over the training data.
        batch_size: mini-batch size for training and testing.
        shuffle_buf_size: size of the shuffle buffer for the training reader.
            It must be much larger than `batch_size`. The old value
            (100 == batch_size) made each batch exactly one buffer of
            consecutive samples, so batches were never mixed.
        learning_rate: Adam learning rate.
        log_period: print the mean training cost every `log_period` batches.
    """
    fill_default_train = functools.partial(
        paddle.dataset.mq2007.train, format="pointwise")
    fill_default_test = functools.partial(
        paddle.dataset.mq2007.test, format="pointwise")
    # A large shuffle buffer decorrelates consecutive samples. The reader
    # presumably emits documents grouped by query, so neighbours are highly
    # correlated (verify against the dataset reader).
    train_reader = paddle.batch(
        paddle.reader.shuffle(fill_default_train, buf_size=shuffle_buf_size),
        batch_size=batch_size)
    test_reader = paddle.batch(fill_default_test, batch_size=batch_size)

    # mq2007 feature_dim = 46, dense format
    feature_dim = 46
    cost = lr(feature_dim)
    parameters = paddle.parameters.create(cost)
    trainer = paddle.trainer.SGD(
        cost=cost,
        parameters=parameters,
        update_equation=paddle.optimizer.Adam(learning_rate=learning_rate))

    # Reader tuple order: position 0 is the label, position 1 the features.
    feeding = {"label": 0, "data": 1}

    # Running totals for the current logging window. A single mini-batch
    # cost is too noisy to judge convergence, so the mean over the window is
    # reported instead. A dict is used so the nested handler can update it
    # without `nonlocal`, which Python 2 lacks.
    window = {"cost": 0.0, "batches": 0}

    def event_handler(event):
        if isinstance(event, paddle.event.EndIteration):
            window["cost"] += event.cost
            window["batches"] += 1
            if event.batch_id % log_period == 0:
                # Mean cost over the batches since the last report.
                print("Train with Pass %d Batch %d Cost %.9f" % (
                    event.pass_id, event.batch_id,
                    window["cost"] / window["batches"]))
                window["cost"] = 0.0
                window["batches"] = 0
            else:
                sys.stdout.write(".")
                sys.stdout.flush()
        if isinstance(event, paddle.event.EndPass):
            # Start each pass with a fresh logging window.
            window["cost"] = 0.0
            window["batches"] = 0
            # The test cost is one aggregate over the whole test split, which
            # is why it looks much smoother than per-batch training cost.
            result = trainer.test(reader=test_reader, feeding=feeding)
            print("\nTest with Pass %d Cost %.9f\n" % (
                event.pass_id, result.cost))
            with gzip.open("lr_params_%d.tar.gz" % (event.pass_id), "w") as f:
                parameters.to_tar(f)

    trainer.train(
        reader=train_reader,
        event_handler=event_handler,
        feeding=feeding,
        num_passes=num_passes)