diff --git a/demo/word2vec/api_train_v2.py b/demo/word2vec/api_train_v2.py index eb61a7250fb84846222d95b7a7722441ad31c924..604adba192ee23aff4359df03d9f5e0ce21b8cde 100644 --- a/demo/word2vec/api_train_v2.py +++ b/demo/word2vec/api_train_v2.py @@ -69,6 +69,7 @@ def main(): def event_handler(event): if isinstance(event, paddle.event.EndIteration): if event.batch_id % 100 == 0: + trainer.save_parameter("output", "batch-" + str(event.batch_id)) result = trainer.test( paddle.batch( paddle.dataset.imikolov.test(word_dict, N), 32)) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 552c6690a608f4cc247fa9d5f774e965cec1f97c..028f25a04676825a51530510abc579d48d2cbf4e 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -1,4 +1,6 @@ import collections +import gzip +import os import py_paddle.swig_paddle as api @@ -96,6 +98,18 @@ class SGD(object): self.__gradient_machine__.prefetch(in_args) self.__parameter_updater__.getParametersRemote() + def save_parameter(self, dir_name, file_name): + if not os.path.exists(dir_name): + os.makedirs(dir_name) + param_file_name = dir_name + "/" + file_name + '.tar.gz' + assert not os.path.exists(param_file_name) + self.__parameter_updater__.catchUpWith() + self.__parameter_updater__.apply() + self.__parameter_updater__.getParametersRemote(True, True) + with gzip.open(param_file_name, 'w') as f: + self.__parameters__.to_tar(f) + self.__parameter_updater__.restore() + def train(self, reader, num_passes=1, event_handler=None, feeding=None): """ Training method. Will train num_passes of input data.