From 956217887fee6403caec9f4bc047237c8f5b9fcc Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Wed, 19 Apr 2017 22:08:54 +0800 Subject: [PATCH] support save parameter in trainer --- demo/word2vec/api_train_v2.py | 1 + python/paddle/v2/trainer.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+) diff --git a/demo/word2vec/api_train_v2.py b/demo/word2vec/api_train_v2.py index eb61a7250..604adba19 100644 --- a/demo/word2vec/api_train_v2.py +++ b/demo/word2vec/api_train_v2.py @@ -69,6 +69,7 @@ def main(): def event_handler(event): if isinstance(event, paddle.event.EndIteration): if event.batch_id % 100 == 0: + trainer.save_parameter("output", "batch-" + str(event.batch_id)) result = trainer.test( paddle.batch( paddle.dataset.imikolov.test(word_dict, N), 32)) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 552c6690a..028f25a04 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -1,4 +1,6 @@ import collections +import gzip +import os import py_paddle.swig_paddle as api @@ -96,6 +98,18 @@ class SGD(object): self.__gradient_machine__.prefetch(in_args) self.__parameter_updater__.getParametersRemote() + def save_parameter(self, dir_name, file_name): + if not os.path.exists(dir_name): + os.makedirs(dir_name) + param_file_name = dir_name + "/" + file_name + '.tar.gz' + assert not os.path.exists(param_file_name) + self.__parameter_updater__.catchUpWith() + self.__parameter_updater__.apply() + self.__parameter_updater__.getParametersRemote(True, True) + with gzip.open(param_file_name, 'w') as f: + self.__parameters__.to_tar(f) + self.__parameter_updater__.restore() + def train(self, reader, num_passes=1, event_handler=None, feeding=None): """ Training method. Will train num_passes of input data. -- GitLab