提交 4a4afde9 编写于 作者: P Peng Li

fix bug in saving model (invoke the wrong function)

上级 d53d098a
......@@ -95,7 +95,7 @@ def train():
args.model_output_prefix, event.pass_id, event.batch_id,
result.cost)
with gzip.open(path, 'w') as f:
params.to_tar(f)
trainer.save_parameter_to_tar(f)
trainer.train(
reader=paddle.batch(
......
......@@ -111,7 +111,7 @@ class DeepSpeech2Model(object):
output_model_path = os.path.join(output_model_dir,
"params.latest.tar.gz")
with gzip.open(output_model_path, 'w') as f:
self._parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
print("\nPass: %d, Batch: %d, TrainCost: %f" %
(event.pass_id, event.batch_id + 1,
cost_sum / cost_counter))
......@@ -136,7 +136,7 @@ class DeepSpeech2Model(object):
output_model_path = os.path.join(
output_model_dir, "params.pass-%d.tar.gz" % event.pass_id)
with gzip.open(output_model_path, 'w') as f:
self._parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
# run train
trainer.train(
......
......@@ -237,7 +237,7 @@ def train(train_data_path=None,
with open("%sdssm_%s_pass_%05d.tar" %
(args.model_output_prefix, model_desc,
event.pass_id), "w") as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
trainer.train(
reader=train_reader,
......
......@@ -12,9 +12,9 @@ logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO)
def save_model(save_path, parameters):
def save_model(trainer, save_path, parameters):
with gzip.open(save_path, "w") as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
def load_initial_model(model_path, parameters):
......@@ -111,7 +111,7 @@ def train(num_passes,
save_path = os.path.join(save_dir_path,
"pass_%05d_batch_%05d.tar.gz" %
(event.pass_id, event.batch_id))
save_model(save_path, parameters)
save_model(trainer, save_path, parameters)
if not event.batch_id % 5:
logger.info("Pass %d, Batch %d, Cost %f, %s" % (
......@@ -120,7 +120,7 @@ def train(num_passes,
if isinstance(event, paddle.event.EndPass):
save_path = os.path.join(save_dir_path,
"pass_%05d.tar.gz" % event.pass_id)
save_model(save_path, parameters)
save_model(trainer, save_path, parameters)
# start training
trainer.train(
......
......@@ -60,7 +60,7 @@ def train(topology,
"rnn_lm_pass_%05d_batch_%03d.tar.gz" %
(event.pass_id, event.batch_id))
with gzip.open(save_name, "w") as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
if isinstance(event, paddle.event.EndPass):
if test_reader is not None:
......@@ -70,7 +70,7 @@ def train(topology,
save_name = os.path.join(model_save_dir, "rnn_lm_pass_%05d.tar.gz" %
(event.pass_id))
with gzip.open(save_name, "w") as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
logger.info("start training...")
trainer.train(
......
......@@ -43,7 +43,7 @@ def load_pretrained_parameters(path):
return np.load(path)
def save_model(save_path, parameters):
def save_model(trainer, save_path, parameters):
""" Save the trained parameters.
Arguments:
......@@ -51,7 +51,7 @@ def save_model(save_path, parameters):
- parameters: The trained model parameters.
"""
with gzip.open(save_path, "w") as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
def show_parameter_init_info(parameters):
......@@ -161,7 +161,7 @@ def build_event_handler(config, parameters, trainer):
if event.batch_id and not event.batch_id % config.checkpoint_period:
save_path = os.path.join(config.save_dir,
"checkpoint_param.latest.tar.gz")
save_model(save_path, parameters)
save_model(trainer, save_path, parameters)
if not event.batch_id % config.log_period:
logger.info("Pass %d, Batch %d, Cost %f" %
......@@ -170,7 +170,7 @@ def build_event_handler(config, parameters, trainer):
if isinstance(event, paddle.event.EndPass):
save_path = os.path.join(config.save_dir,
"pass_%05d.tar.gz" % event.pass_id)
save_model(save_path, parameters)
save_model(trainer, save_path, parameters)
return event_handler
......
......@@ -18,13 +18,19 @@ def main(save_dir="models"):
dict_size = len(word_dict)
cost = ngram_lm(hidden_size=256, embed_size=32, dict_size=dict_size)
parameters = paddle.parameters.create(cost)
adam_optimizer = paddle.optimizer.Adam(
learning_rate=3e-3,
regularization=paddle.optimizer.L2Regularization(8e-4))
trainer = paddle.trainer.SGD(cost, parameters, adam_optimizer)
def event_handler(event):
if isinstance(event, paddle.event.EndPass):
model_name = os.path.join(save_dir, "hsigmoid_pass_%05d.tar.gz" %
event.pass_id)
logger.info("Save model into %s ..." % model_name)
with gzip.open(model_name, "w") as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
if isinstance(event, paddle.event.EndIteration):
if event.batch_id and event.batch_id % 10 == 0:
......@@ -35,12 +41,6 @@ def main(save_dir="models"):
"Pass %d, Batch %d, Cost %f, Test Cost %f" %
(event.pass_id, event.batch_id, event.cost, result.cost))
parameters = paddle.parameters.create(cost)
adam_optimizer = paddle.optimizer.Adam(
learning_rate=3e-3,
regularization=paddle.optimizer.L2Regularization(8e-4))
trainer = paddle.trainer.SGD(cost, parameters, adam_optimizer)
trainer.train(
paddle.batch(
paddle.reader.shuffle(
......
......@@ -81,6 +81,13 @@ def main():
# reader.test_reader('val.list'),
batch_size=BATCH_SIZE)
# Create trainer
trainer = paddle.trainer.SGD(
cost=cost,
parameters=parameters,
update_equation=optimizer,
extra_layers=extra_layers)
# End batch and end pass event handler
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
......@@ -89,21 +96,14 @@ def main():
event.pass_id, event.batch_id, event.cost, event.metrics)
if isinstance(event, paddle.event.EndPass):
with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
result = trainer.test(reader=test_reader)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
# Create trainer
trainer = paddle.trainer.SGD(
cost=cost,
parameters=parameters,
update_equation=optimizer,
extra_layers=extra_layers)
trainer.train(
reader=train_reader, num_passes=200, event_handler=event_handler)
if __name__ == '__main__':
main()
\ No newline at end of file
main()
......@@ -82,7 +82,7 @@ def train_lambda_rank(num_passes):
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
with gzip.open("lambda_rank_params_%d.tar.gz" % (event.pass_id),
"w") as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
feeding = {"label": 0, "data": 1}
trainer.train(
......
......@@ -86,7 +86,7 @@ def train_ranknet(num_passes):
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
with gzip.open("ranknet_params_%d.tar.gz" % (event.pass_id),
"w") as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
trainer.train(
reader=train_reader,
......
......@@ -129,7 +129,7 @@ def train():
print "Pass: %d, Batch: %d, TrainCost: %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
with gzip.open("checkpoints/params.latest.tar.gz", 'w') as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
else:
sys.stdout.write('.')
sys.stdout.flush()
......@@ -139,7 +139,7 @@ def train():
result.metrics)
with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id,
'w') as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
# run train
if not os.path.exists('checkpoints'):
......
......@@ -38,7 +38,7 @@ def train(model_save_dir):
"model_pass_%05d.tar.gz" % event.pass_id)
logger.info("Save model into %s ..." % save_path)
with gzip.open(save_path, "w") as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
trainer.train(
paddle.batch(
......
......@@ -179,7 +179,7 @@ def train(train_data_dir, test_data_dir, word_dict_path, label_dict_path,
with gzip.open(
os.path.join(model_save_dir, "params_pass_%05d.tar.gz" %
event.pass_id), "w") as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
# begin training network
trainer.train(
......
......@@ -50,7 +50,7 @@ def train(save_dir_path, source_dict_dim, target_dict_dim):
os.path.join(save_path,
"nmt_without_att_%05d_batch_%05d.tar.gz" %
event.pass_id, event.batch_id), "w") as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
if event.batch_id and not event.batch_id % 10:
logger.info("Pass %d, Batch %d, Cost %f, %s" % (
......
......@@ -264,7 +264,7 @@ def main():
# save parameters
with gzip.open('params_pass_%d.tar.gz' % event.pass_id,
'w') as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
# start to train
trainer.train(
......
......@@ -87,7 +87,7 @@ def main(train_data_file,
with gzip.open(
os.path.join(model_save_dir, "params_pass_%d.tar.gz" %
event.pass_id), "w") as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
result = trainer.test(reader=test_reader, feeding=feeding)
logger.info("\nTest with Pass %d, %s" % (event.pass_id,
......
......@@ -55,7 +55,7 @@ def train(train_file_list, dev_file_list, data_args, init_model_path):
if isinstance(event, paddle.event.EndPass):
with gzip.open('checkpoints/params_pass_%05d.tar.gz' % \
event.pass_id, 'w') as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
result = trainer.test(reader=dev_reader, feeding=feeding)
print "\nTest with Pass %d, TestCost: %f, Detection mAP=%g" % \
(event.pass_id,
......
......@@ -141,7 +141,7 @@ def train(topology,
with gzip.open(
os.path.join(model_save_dir, "dnn_params_pass_%05d.tar.gz" %
event.pass_id), "w") as f:
parameters.to_tar(f)
trainer.save_parameter_to_tar(f)
trainer.train(
reader=train_reader,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册