提交 6a2776e1 编写于 作者: Q qiaolongfei

save_parameter_to_tar to fd

上级 9e9d4562
import gzip
import math
import paddle.v2 as paddle
......@@ -69,8 +70,8 @@ def main():
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0:
trainer.save_parameter_to_tar("output",
"batch-" + str(event.batch_id))
with gzip.open("batch-" + str(event.batch_id), 'w') as f:
trainer.save_parameter_to_tar(f)
result = trainer.test(
paddle.batch(
paddle.dataset.imikolov.test(word_dict, N), 32))
......
......@@ -98,16 +98,11 @@ class SGD(object):
self.__gradient_machine__.prefetch(in_args)
self.__parameter_updater__.getParametersRemote()
def save_parameter_to_tar(self, dir_name, file_name):
if not os.path.exists(dir_name):
os.makedirs(dir_name)
param_file_name = dir_name + "/" + file_name + '.tar.gz'
assert not os.path.exists(param_file_name)
def save_parameter_to_tar(self, f):
self.__parameter_updater__.catchUpWith()
self.__parameter_updater__.apply()
self.__parameter_updater__.getParametersRemote(True, True)
with gzip.open(param_file_name, 'w') as f:
self.__parameters__.to_tar(f)
self.__parameters__.to_tar(f)
self.__parameter_updater__.restore()
def train(self, reader, num_passes=1, event_handler=None, feeding=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册