提交 da88f4fb 编写于 作者: M mapingshuo

use clone(for_test=True) to replace get_inference_program for fluid version 1.0

上级 dba26527
......@@ -112,8 +112,13 @@ def train_and_evaluate(train_reader,
print("Parallel Training is not supported for now.")
sys.exit(1)
optimizer.minimize(cost)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
#optimizer.minimize(cost)
if use_cuda:
print("Using GPU")
place = fluid.CUDAPlace(0)
else:
print("Using CPU")
place = fluid.CPUPlace()
exe = fluid.Executor(place)
if global_config.use_lod_tensor:
......@@ -126,10 +131,10 @@ def train_and_evaluate(train_reader,
print("param name: %s; param shape: %s" % (param.name, param.shape))
# define inference_program
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
inference_program = fluid.io.get_inference_program([cost, acc])
inference_program = fluid.default_main_program().clone(for_test=True)
optimizer.minimize(cost)
exe.run(fluid.default_startup_program())
# load emb from a numpy erray
......@@ -234,7 +239,7 @@ def main():
optimizer,
global_config,
pretrained_word_embedding,
use_cuda=True,
use_cuda=global_config.use_cuda,
parallel=False)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册