提交 a77d2866 编写于 作者: Z Zeyu Chen

move remove feed and fetch op into Module.__init__()

上级 1282b844
......@@ -124,6 +124,11 @@ def train_net(train_reader,
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
sgd_optimizer.minimize(cost)
# write default main program
with open("./bow_net.backward.program_desc.prototxt", "w") as fo:
program_desc = str(fluid.default_main_program())
fo.write(program_desc)
# set place, executor, datafeeder
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
......@@ -191,7 +196,6 @@ def retrain_net(train_reader,
dict_dim = len(word_dict) + 2
module_dir = os.path.join(save_dirname, network_name)
print("module_dir", module_dir)
module = hub.Module(module_dir=module_dir)
main_program = fluid.Program()
......@@ -201,7 +205,7 @@ def retrain_net(train_reader,
fluid.framework.switch_main_program(module.get_inference_program())
# remove feed fetch operator and variable
hub.ModuleUtils.remove_feed_fetch_op(fluid.default_main_program())
# hub.ModuleUtils.remove_feed_fetch_op(fluid.default_main_program())
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
data = module.get_feed_var_by_index(0)
......@@ -221,6 +225,9 @@ def retrain_net(train_reader,
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=lr)
sgd_optimizer.minimize(cost)
with open("./prototxt/bow_net.finetune.program_desc.prototxt", "w") as fo:
program_desc = str(fluid.default_main_program())
fo.write(program_desc)
# set place, executor, datafeeder
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
......
......@@ -145,8 +145,7 @@ def retrain_net(train_reader,
fluid.framework.switch_main_program(module.get_inference_program())
# remove feed fetch operator and variable
hub.ModuleUtils.remove_feed_fetch_op(fluid.default_main_program())
# remove_feed_fetch_op(fluid.default_main_program())
# hub.ModuleUtils.remove_feed_fetch_op(fluid.default_main_program())
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
#data = fluid.default_main_program().global_block().var("words")
......
......@@ -72,6 +72,9 @@ class Module(object):
self.fetch_targets] = fluid.io.load_inference_model(
dirname=model_dir, executor=self.exe)
# remove feed fetch operator and variable
ModuleUtils.remove_feed_fetch_op(self.inference_program)
print("inference_program")
print(self.inference_program)
print("feed_target_names")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册