提交 4a4afde9 编写于 作者: Peng Li

Fix a bug in model saving (the wrong function was being invoked)

上级 d53d098a
...@@ -95,7 +95,7 @@ def train(): ...@@ -95,7 +95,7 @@ def train():
args.model_output_prefix, event.pass_id, event.batch_id, args.model_output_prefix, event.pass_id, event.batch_id,
result.cost) result.cost)
with gzip.open(path, 'w') as f: with gzip.open(path, 'w') as f:
params.to_tar(f) trainer.save_parameter_to_tar(f)
trainer.train( trainer.train(
reader=paddle.batch( reader=paddle.batch(
......
...@@ -111,7 +111,7 @@ class DeepSpeech2Model(object): ...@@ -111,7 +111,7 @@ class DeepSpeech2Model(object):
output_model_path = os.path.join(output_model_dir, output_model_path = os.path.join(output_model_dir,
"params.latest.tar.gz") "params.latest.tar.gz")
with gzip.open(output_model_path, 'w') as f: with gzip.open(output_model_path, 'w') as f:
self._parameters.to_tar(f) trainer.save_parameter_to_tar(f)
print("\nPass: %d, Batch: %d, TrainCost: %f" % print("\nPass: %d, Batch: %d, TrainCost: %f" %
(event.pass_id, event.batch_id + 1, (event.pass_id, event.batch_id + 1,
cost_sum / cost_counter)) cost_sum / cost_counter))
...@@ -136,7 +136,7 @@ class DeepSpeech2Model(object): ...@@ -136,7 +136,7 @@ class DeepSpeech2Model(object):
output_model_path = os.path.join( output_model_path = os.path.join(
output_model_dir, "params.pass-%d.tar.gz" % event.pass_id) output_model_dir, "params.pass-%d.tar.gz" % event.pass_id)
with gzip.open(output_model_path, 'w') as f: with gzip.open(output_model_path, 'w') as f:
self._parameters.to_tar(f) trainer.save_parameter_to_tar(f)
# run train # run train
trainer.train( trainer.train(
......
...@@ -237,7 +237,7 @@ def train(train_data_path=None, ...@@ -237,7 +237,7 @@ def train(train_data_path=None,
with open("%sdssm_%s_pass_%05d.tar" % with open("%sdssm_%s_pass_%05d.tar" %
(args.model_output_prefix, model_desc, (args.model_output_prefix, model_desc,
event.pass_id), "w") as f: event.pass_id), "w") as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
trainer.train( trainer.train(
reader=train_reader, reader=train_reader,
......
...@@ -12,9 +12,9 @@ logger = logging.getLogger("paddle") ...@@ -12,9 +12,9 @@ logger = logging.getLogger("paddle")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
def save_model(save_path, parameters): def save_model(trainer, save_path, parameters):
with gzip.open(save_path, "w") as f: with gzip.open(save_path, "w") as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
def load_initial_model(model_path, parameters): def load_initial_model(model_path, parameters):
...@@ -111,7 +111,7 @@ def train(num_passes, ...@@ -111,7 +111,7 @@ def train(num_passes,
save_path = os.path.join(save_dir_path, save_path = os.path.join(save_dir_path,
"pass_%05d_batch_%05d.tar.gz" % "pass_%05d_batch_%05d.tar.gz" %
(event.pass_id, event.batch_id)) (event.pass_id, event.batch_id))
save_model(save_path, parameters) save_model(trainer, save_path, parameters)
if not event.batch_id % 5: if not event.batch_id % 5:
logger.info("Pass %d, Batch %d, Cost %f, %s" % ( logger.info("Pass %d, Batch %d, Cost %f, %s" % (
...@@ -120,7 +120,7 @@ def train(num_passes, ...@@ -120,7 +120,7 @@ def train(num_passes,
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
save_path = os.path.join(save_dir_path, save_path = os.path.join(save_dir_path,
"pass_%05d.tar.gz" % event.pass_id) "pass_%05d.tar.gz" % event.pass_id)
save_model(save_path, parameters) save_model(trainer, save_path, parameters)
# start training # start training
trainer.train( trainer.train(
......
...@@ -60,7 +60,7 @@ def train(topology, ...@@ -60,7 +60,7 @@ def train(topology,
"rnn_lm_pass_%05d_batch_%03d.tar.gz" % "rnn_lm_pass_%05d_batch_%03d.tar.gz" %
(event.pass_id, event.batch_id)) (event.pass_id, event.batch_id))
with gzip.open(save_name, "w") as f: with gzip.open(save_name, "w") as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
if test_reader is not None: if test_reader is not None:
...@@ -70,7 +70,7 @@ def train(topology, ...@@ -70,7 +70,7 @@ def train(topology,
save_name = os.path.join(model_save_dir, "rnn_lm_pass_%05d.tar.gz" % save_name = os.path.join(model_save_dir, "rnn_lm_pass_%05d.tar.gz" %
(event.pass_id)) (event.pass_id))
with gzip.open(save_name, "w") as f: with gzip.open(save_name, "w") as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
logger.info("start training...") logger.info("start training...")
trainer.train( trainer.train(
......
...@@ -43,7 +43,7 @@ def load_pretrained_parameters(path): ...@@ -43,7 +43,7 @@ def load_pretrained_parameters(path):
return np.load(path) return np.load(path)
def save_model(save_path, parameters): def save_model(trainer, save_path, parameters):
""" Save the trained parameters. """ Save the trained parameters.
Arguments: Arguments:
...@@ -51,7 +51,7 @@ def save_model(save_path, parameters): ...@@ -51,7 +51,7 @@ def save_model(save_path, parameters):
- parameters: The trained model parameters. - parameters: The trained model parameters.
""" """
with gzip.open(save_path, "w") as f: with gzip.open(save_path, "w") as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
def show_parameter_init_info(parameters): def show_parameter_init_info(parameters):
...@@ -161,7 +161,7 @@ def build_event_handler(config, parameters, trainer): ...@@ -161,7 +161,7 @@ def build_event_handler(config, parameters, trainer):
if event.batch_id and not event.batch_id % config.checkpoint_period: if event.batch_id and not event.batch_id % config.checkpoint_period:
save_path = os.path.join(config.save_dir, save_path = os.path.join(config.save_dir,
"checkpoint_param.latest.tar.gz") "checkpoint_param.latest.tar.gz")
save_model(save_path, parameters) save_model(trainer, save_path, parameters)
if not event.batch_id % config.log_period: if not event.batch_id % config.log_period:
logger.info("Pass %d, Batch %d, Cost %f" % logger.info("Pass %d, Batch %d, Cost %f" %
...@@ -170,7 +170,7 @@ def build_event_handler(config, parameters, trainer): ...@@ -170,7 +170,7 @@ def build_event_handler(config, parameters, trainer):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
save_path = os.path.join(config.save_dir, save_path = os.path.join(config.save_dir,
"pass_%05d.tar.gz" % event.pass_id) "pass_%05d.tar.gz" % event.pass_id)
save_model(save_path, parameters) save_model(trainer, save_path, parameters)
return event_handler return event_handler
......
...@@ -18,13 +18,19 @@ def main(save_dir="models"): ...@@ -18,13 +18,19 @@ def main(save_dir="models"):
dict_size = len(word_dict) dict_size = len(word_dict)
cost = ngram_lm(hidden_size=256, embed_size=32, dict_size=dict_size) cost = ngram_lm(hidden_size=256, embed_size=32, dict_size=dict_size)
parameters = paddle.parameters.create(cost)
adam_optimizer = paddle.optimizer.Adam(
learning_rate=3e-3,
regularization=paddle.optimizer.L2Regularization(8e-4))
trainer = paddle.trainer.SGD(cost, parameters, adam_optimizer)
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
model_name = os.path.join(save_dir, "hsigmoid_pass_%05d.tar.gz" % model_name = os.path.join(save_dir, "hsigmoid_pass_%05d.tar.gz" %
event.pass_id) event.pass_id)
logger.info("Save model into %s ..." % model_name) logger.info("Save model into %s ..." % model_name)
with gzip.open(model_name, "w") as f: with gzip.open(model_name, "w") as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
if event.batch_id and event.batch_id % 10 == 0: if event.batch_id and event.batch_id % 10 == 0:
...@@ -35,12 +41,6 @@ def main(save_dir="models"): ...@@ -35,12 +41,6 @@ def main(save_dir="models"):
"Pass %d, Batch %d, Cost %f, Test Cost %f" % "Pass %d, Batch %d, Cost %f, Test Cost %f" %
(event.pass_id, event.batch_id, event.cost, result.cost)) (event.pass_id, event.batch_id, event.cost, result.cost))
parameters = paddle.parameters.create(cost)
adam_optimizer = paddle.optimizer.Adam(
learning_rate=3e-3,
regularization=paddle.optimizer.L2Regularization(8e-4))
trainer = paddle.trainer.SGD(cost, parameters, adam_optimizer)
trainer.train( trainer.train(
paddle.batch( paddle.batch(
paddle.reader.shuffle( paddle.reader.shuffle(
......
...@@ -81,6 +81,13 @@ def main(): ...@@ -81,6 +81,13 @@ def main():
# reader.test_reader('val.list'), # reader.test_reader('val.list'),
batch_size=BATCH_SIZE) batch_size=BATCH_SIZE)
# Create trainer
trainer = paddle.trainer.SGD(
cost=cost,
parameters=parameters,
update_equation=optimizer,
extra_layers=extra_layers)
# End batch and end pass event handler # End batch and end pass event handler
def event_handler(event): def event_handler(event):
if isinstance(event, paddle.event.EndIteration): if isinstance(event, paddle.event.EndIteration):
...@@ -89,18 +96,11 @@ def main(): ...@@ -89,18 +96,11 @@ def main():
event.pass_id, event.batch_id, event.cost, event.metrics) event.pass_id, event.batch_id, event.cost, event.metrics)
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f: with gzip.open('params_pass_%d.tar.gz' % event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=test_reader) result = trainer.test(reader=test_reader)
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
# Create trainer
trainer = paddle.trainer.SGD(
cost=cost,
parameters=parameters,
update_equation=optimizer,
extra_layers=extra_layers)
trainer.train( trainer.train(
reader=train_reader, num_passes=200, event_handler=event_handler) reader=train_reader, num_passes=200, event_handler=event_handler)
......
...@@ -82,7 +82,7 @@ def train_lambda_rank(num_passes): ...@@ -82,7 +82,7 @@ def train_lambda_rank(num_passes):
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
with gzip.open("lambda_rank_params_%d.tar.gz" % (event.pass_id), with gzip.open("lambda_rank_params_%d.tar.gz" % (event.pass_id),
"w") as f: "w") as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
feeding = {"label": 0, "data": 1} feeding = {"label": 0, "data": 1}
trainer.train( trainer.train(
......
...@@ -86,7 +86,7 @@ def train_ranknet(num_passes): ...@@ -86,7 +86,7 @@ def train_ranknet(num_passes):
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
with gzip.open("ranknet_params_%d.tar.gz" % (event.pass_id), with gzip.open("ranknet_params_%d.tar.gz" % (event.pass_id),
"w") as f: "w") as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
trainer.train( trainer.train(
reader=train_reader, reader=train_reader,
......
...@@ -129,7 +129,7 @@ def train(): ...@@ -129,7 +129,7 @@ def train():
print "Pass: %d, Batch: %d, TrainCost: %f, %s" % ( print "Pass: %d, Batch: %d, TrainCost: %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics) event.pass_id, event.batch_id, event.cost, event.metrics)
with gzip.open("checkpoints/params.latest.tar.gz", 'w') as f: with gzip.open("checkpoints/params.latest.tar.gz", 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
else: else:
sys.stdout.write('.') sys.stdout.write('.')
sys.stdout.flush() sys.stdout.flush()
...@@ -139,7 +139,7 @@ def train(): ...@@ -139,7 +139,7 @@ def train():
result.metrics) result.metrics)
with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id, with gzip.open("checkpoints/params.pass-%d.tar.gz" % event.pass_id,
'w') as f: 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
# run train # run train
if not os.path.exists('checkpoints'): if not os.path.exists('checkpoints'):
......
...@@ -38,7 +38,7 @@ def train(model_save_dir): ...@@ -38,7 +38,7 @@ def train(model_save_dir):
"model_pass_%05d.tar.gz" % event.pass_id) "model_pass_%05d.tar.gz" % event.pass_id)
logger.info("Save model into %s ..." % save_path) logger.info("Save model into %s ..." % save_path)
with gzip.open(save_path, "w") as f: with gzip.open(save_path, "w") as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
trainer.train( trainer.train(
paddle.batch( paddle.batch(
......
...@@ -179,7 +179,7 @@ def train(train_data_dir, test_data_dir, word_dict_path, label_dict_path, ...@@ -179,7 +179,7 @@ def train(train_data_dir, test_data_dir, word_dict_path, label_dict_path,
with gzip.open( with gzip.open(
os.path.join(model_save_dir, "params_pass_%05d.tar.gz" % os.path.join(model_save_dir, "params_pass_%05d.tar.gz" %
event.pass_id), "w") as f: event.pass_id), "w") as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
# begin training network # begin training network
trainer.train( trainer.train(
......
...@@ -50,7 +50,7 @@ def train(save_dir_path, source_dict_dim, target_dict_dim): ...@@ -50,7 +50,7 @@ def train(save_dir_path, source_dict_dim, target_dict_dim):
os.path.join(save_path, os.path.join(save_path,
"nmt_without_att_%05d_batch_%05d.tar.gz" % "nmt_without_att_%05d_batch_%05d.tar.gz" %
event.pass_id, event.batch_id), "w") as f: event.pass_id, event.batch_id), "w") as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
if event.batch_id and not event.batch_id % 10: if event.batch_id and not event.batch_id % 10:
logger.info("Pass %d, Batch %d, Cost %f, %s" % ( logger.info("Pass %d, Batch %d, Cost %f, %s" % (
......
...@@ -264,7 +264,7 @@ def main(): ...@@ -264,7 +264,7 @@ def main():
# save parameters # save parameters
with gzip.open('params_pass_%d.tar.gz' % event.pass_id, with gzip.open('params_pass_%d.tar.gz' % event.pass_id,
'w') as f: 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
# start to train # start to train
trainer.train( trainer.train(
......
...@@ -87,7 +87,7 @@ def main(train_data_file, ...@@ -87,7 +87,7 @@ def main(train_data_file,
with gzip.open( with gzip.open(
os.path.join(model_save_dir, "params_pass_%d.tar.gz" % os.path.join(model_save_dir, "params_pass_%d.tar.gz" %
event.pass_id), "w") as f: event.pass_id), "w") as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=test_reader, feeding=feeding) result = trainer.test(reader=test_reader, feeding=feeding)
logger.info("\nTest with Pass %d, %s" % (event.pass_id, logger.info("\nTest with Pass %d, %s" % (event.pass_id,
......
...@@ -55,7 +55,7 @@ def train(train_file_list, dev_file_list, data_args, init_model_path): ...@@ -55,7 +55,7 @@ def train(train_file_list, dev_file_list, data_args, init_model_path):
if isinstance(event, paddle.event.EndPass): if isinstance(event, paddle.event.EndPass):
with gzip.open('checkpoints/params_pass_%05d.tar.gz' % \ with gzip.open('checkpoints/params_pass_%05d.tar.gz' % \
event.pass_id, 'w') as f: event.pass_id, 'w') as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
result = trainer.test(reader=dev_reader, feeding=feeding) result = trainer.test(reader=dev_reader, feeding=feeding)
print "\nTest with Pass %d, TestCost: %f, Detection mAP=%g" % \ print "\nTest with Pass %d, TestCost: %f, Detection mAP=%g" % \
(event.pass_id, (event.pass_id,
......
...@@ -141,7 +141,7 @@ def train(topology, ...@@ -141,7 +141,7 @@ def train(topology,
with gzip.open( with gzip.open(
os.path.join(model_save_dir, "dnn_params_pass_%05d.tar.gz" % os.path.join(model_save_dir, "dnn_params_pass_%05d.tar.gz" %
event.pass_id), "w") as f: event.pass_id), "w") as f:
parameters.to_tar(f) trainer.save_parameter_to_tar(f)
trainer.train( trainer.train(
reader=train_reader, reader=train_reader,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册