提交 338c4ab4 编写于 作者: J jingqinghe

generate job from pre-defined program

上级 ec00301f
...@@ -214,6 +214,9 @@ class JobGenerator(object): ...@@ -214,6 +214,9 @@ class JobGenerator(object):
current_para = line[:-1] current_para = line[:-1]
para_list.append(current_para) para_list.append(current_para)
with open(program_input + '/loss_name', 'rb') as fin:
loss_name = fin.read()
for item in para_list: for item in para_list:
para = new_main.global_block().var(item) para = new_main.global_block().var(item)
para.regularizer = None para.regularizer = None
...@@ -224,7 +227,7 @@ class JobGenerator(object): ...@@ -224,7 +227,7 @@ class JobGenerator(object):
input = None input = None
label = None label = None
for var in new_main.list_vars(): for var in new_main.list_vars():
if var.name == "loss.tmp_0": if var.name == loss_name:
loss = var loss = var
if var.name == 'input': if var.name == 'input':
input = var input = var
......
...@@ -24,7 +24,6 @@ sum_cost = fluid.layers.cross_entropy(input=predict, label=label) ...@@ -24,7 +24,6 @@ sum_cost = fluid.layers.cross_entropy(input=predict, label=label)
accuracy = fluid.layers.accuracy(input=predict, label=label) accuracy = fluid.layers.accuracy(input=predict, label=label)
avg_cost = fluid.layers.mean(sum_cost, name="loss") avg_cost = fluid.layers.mean(sum_cost, name="loss")
startup_program = fluid.default_startup_program() startup_program = fluid.default_startup_program()
place = fluid.CPUPlace() place = fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_program) exe.run(startup_program)
...@@ -47,6 +46,8 @@ def save_program(program_path): ...@@ -47,6 +46,8 @@ def save_program(program_path):
fout.write(startup_program_str) fout.write(startup_program_str)
with open(program_path + '/main_program', "wb") as fout: with open(program_path + '/main_program', "wb") as fout:
fout.write(main_program_str) fout.write(main_program_str)
with open(program_path + '/loss_name', 'wb') as fout:
fout.write(avg_cost.name)
program_path = './load_file' program_path = './load_file'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册