diff --git a/demo/word2vec/api_train_v2.py b/demo/word2vec/api_train_v2.py index 98ade830cf51e69e6b9ff4097090219c100f0cdf..a224951f4d4c86807294fa2fbcca2cfc38af915f 100644 --- a/demo/word2vec/api_train_v2.py +++ b/demo/word2vec/api_train_v2.py @@ -1,3 +1,4 @@ +import gzip import math import paddle.v2 as paddle @@ -69,8 +70,8 @@ def main(): def event_handler(event): if isinstance(event, paddle.event.EndIteration): if event.batch_id % 100 == 0: - trainer.save_parameter_to_tar("output", - "batch-" + str(event.batch_id)) + with gzip.open("batch-" + str(event.batch_id), 'w') as f: + trainer.save_parameter_to_tar(f) result = trainer.test( paddle.batch( paddle.dataset.imikolov.test(word_dict, N), 32)) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 220d459525f0dccc58a56b149f841ca059bb5977..6a83ba8533cd67821de2f0c730608575d37b0431 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -98,16 +98,11 @@ class SGD(object): self.__gradient_machine__.prefetch(in_args) self.__parameter_updater__.getParametersRemote() - def save_parameter_to_tar(self, dir_name, file_name): - if not os.path.exists(dir_name): - os.makedirs(dir_name) - param_file_name = dir_name + "/" + file_name + '.tar.gz' - assert not os.path.exists(param_file_name) + def save_parameter_to_tar(self, f): self.__parameter_updater__.catchUpWith() self.__parameter_updater__.apply() self.__parameter_updater__.getParametersRemote(True, True) - with gzip.open(param_file_name, 'w') as f: - self.__parameters__.to_tar(f) + self.__parameters__.to_tar(f) self.__parameter_updater__.restore() def train(self, reader, num_passes=1, event_handler=None, feeding=None):