From 6a2776e139b7ba886daaa7ca0026f2f192b30c73 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 24 Apr 2017 16:29:32 +0800 Subject: [PATCH] save_parameter_to_tar to fd --- demo/word2vec/api_train_v2.py | 5 +++-- python/paddle/v2/trainer.py | 9 ++------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/demo/word2vec/api_train_v2.py b/demo/word2vec/api_train_v2.py index 98ade830c..a224951f4 100644 --- a/demo/word2vec/api_train_v2.py +++ b/demo/word2vec/api_train_v2.py @@ -1,3 +1,4 @@ +import gzip import math import paddle.v2 as paddle @@ -69,8 +70,8 @@ def main(): def event_handler(event): if isinstance(event, paddle.event.EndIteration): if event.batch_id % 100 == 0: - trainer.save_parameter_to_tar("output", - "batch-" + str(event.batch_id)) + with gzip.open("batch-" + str(event.batch_id), 'w') as f: + trainer.save_parameter_to_tar(f) result = trainer.test( paddle.batch( paddle.dataset.imikolov.test(word_dict, N), 32)) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 220d45952..6a83ba853 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -98,16 +98,11 @@ class SGD(object): self.__gradient_machine__.prefetch(in_args) self.__parameter_updater__.getParametersRemote() - def save_parameter_to_tar(self, dir_name, file_name): - if not os.path.exists(dir_name): - os.makedirs(dir_name) - param_file_name = dir_name + "/" + file_name + '.tar.gz' - assert not os.path.exists(param_file_name) + def save_parameter_to_tar(self, f): self.__parameter_updater__.catchUpWith() self.__parameter_updater__.apply() self.__parameter_updater__.getParametersRemote(True, True) - with gzip.open(param_file_name, 'w') as f: - self.__parameters__.to_tar(f) + self.__parameters__.to_tar(f) self.__parameter_updater__.restore() def train(self, reader, num_passes=1, event_handler=None, feeding=None): -- GitLab