提交 76bd4e8d 编写于 作者: J jingqinghe

update program_saver

上级 ab95661b
...@@ -199,7 +199,7 @@ class JobGenerator(object): ...@@ -199,7 +199,7 @@ class JobGenerator(object):
local_job.set_strategy(fl_strategy) local_job.set_strategy(fl_strategy)
local_job.save(output) local_job.save(output)
def save_program(self, program_path, loss): def save_program(self, program_path, input_list, loss):
if not os.path.exists(program_path): if not os.path.exists(program_path):
os.makedirs(program_path) os.makedirs(program_path)
main_program_str = fluid.default_main_program( main_program_str = fluid.default_main_program(
...@@ -210,6 +210,9 @@ class JobGenerator(object): ...@@ -210,6 +210,9 @@ class JobGenerator(object):
para_info = [] para_info = []
for pa in params: for pa in params:
para_info.append(pa.name) para_info.append(pa.name)
with open(program_path + '/input_names', 'w') as fout:
for input in input_list:
fout.write("%s\n" % input.name)
with open(program_path + '/para_info', 'w') as fout: with open(program_path + '/para_info', 'w') as fout:
for item in para_info: for item in para_info:
fout.write("%s\n" % item) fout.write("%s\n" % item)
...@@ -230,12 +233,19 @@ class JobGenerator(object): ...@@ -230,12 +233,19 @@ class JobGenerator(object):
with open(program_input + '/main_program', "rb") as fin: with open(program_input + '/main_program', "rb") as fin:
program_desc_str = fin.read() program_desc_str = fin.read()
new_main = fluid.Program.parse_from_string(program_desc_str) new_main = fluid.Program.parse_from_string(program_desc_str)
para_list = [] para_list = []
with open(program_input + '/para_info', 'r') as fin: with open(program_input + '/para_info', 'r') as fin:
for line in fin: for line in fin:
current_para = line[:-1] current_para = line[:-1]
para_list.append(current_para) para_list.append(current_para)
input_list = []
with open(program_input + '/input_names', 'r') as fin:
for line in fin:
current_input = line[:-1]
input_list.append(current_input)
with open(program_input + '/loss_name', 'r') as fin: with open(program_input + '/loss_name', 'r') as fin:
loss_name = fin.read() loss_name = fin.read()
...@@ -251,10 +261,6 @@ class JobGenerator(object): ...@@ -251,10 +261,6 @@ class JobGenerator(object):
for var in new_main.list_vars(): for var in new_main.list_vars():
if var.name == loss_name: if var.name == loss_name:
loss = var loss = var
if var.name == 'input':
input = var
if var.name == 'label':
label = var
with fluid.program_guard(new_main, new_startup): with fluid.program_guard(new_main, new_startup):
optimizer = fluid.optimizer.SGD(learning_rate=0.1, optimizer = fluid.optimizer.SGD(learning_rate=0.1,
parameter_list=para_list) parameter_list=para_list)
...@@ -283,7 +289,7 @@ class JobGenerator(object): ...@@ -283,7 +289,7 @@ class JobGenerator(object):
startup_program=startup_program, startup_program=startup_program,
job=local_job) job=local_job)
local_job.set_feed_names([input.name, label.name]) local_job.set_feed_names(input_list)
local_job.set_target_names([loss.name]) local_job.set_target_names([loss.name])
local_job.set_strategy(strategy) local_job.set_strategy(strategy)
local_job.save(output) local_job.save(output)
...@@ -32,4 +32,4 @@ exe.run(startup_program) ...@@ -32,4 +32,4 @@ exe.run(startup_program)
job_generator = JobGenerator() job_generator = JobGenerator()
program_path = './load_file' program_path = './load_file'
job_generator.save_program(program_path, avg_cost) job_generator.save_program(program_path, [input, label], avg_cost)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册