提交 76bd4e8d 编写于 作者: J jingqinghe

update program_saver

上级 ab95661b
......@@ -199,7 +199,7 @@ class JobGenerator(object):
local_job.set_strategy(fl_strategy)
local_job.save(output)
def save_program(self, program_path, loss):
def save_program(self, program_path, input_list, loss):
if not os.path.exists(program_path):
os.makedirs(program_path)
main_program_str = fluid.default_main_program(
......@@ -210,6 +210,9 @@ class JobGenerator(object):
para_info = []
for pa in params:
para_info.append(pa.name)
with open(program_path + '/input_names', 'w') as fout:
for input in input_list:
fout.write("%s\n" % input.name)
with open(program_path + '/para_info', 'w') as fout:
for item in para_info:
fout.write("%s\n" % item)
......@@ -230,12 +233,19 @@ class JobGenerator(object):
with open(program_input + '/main_program', "rb") as fin:
program_desc_str = fin.read()
new_main = fluid.Program.parse_from_string(program_desc_str)
para_list = []
with open(program_input + '/para_info', 'r') as fin:
for line in fin:
current_para = line[:-1]
para_list.append(current_para)
input_list = []
with open(program_input + '/input_names', 'r') as fin:
for line in fin:
current_input = line[:-1]
input_list.append(current_input)
with open(program_input + '/loss_name', 'r') as fin:
loss_name = fin.read()
......@@ -251,10 +261,6 @@ class JobGenerator(object):
for var in new_main.list_vars():
if var.name == loss_name:
loss = var
if var.name == 'input':
input = var
if var.name == 'label':
label = var
with fluid.program_guard(new_main, new_startup):
optimizer = fluid.optimizer.SGD(learning_rate=0.1,
parameter_list=para_list)
......@@ -283,7 +289,7 @@ class JobGenerator(object):
startup_program=startup_program,
job=local_job)
local_job.set_feed_names([input.name, label.name])
local_job.set_feed_names(input_list)
local_job.set_target_names([loss.name])
local_job.set_strategy(strategy)
local_job.save(output)
......@@ -32,4 +32,4 @@ exe.run(startup_program)
job_generator = JobGenerator()
program_path = './load_file'
job_generator.save_program(program_path, avg_cost)
job_generator.save_program(program_path, [input, label], avg_cost)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册