提交 0f4edd78 编写于 作者: L luotao1

use clone(for_test=True) replace get_inference_program

上级 78fd552b
......@@ -282,10 +282,7 @@ def main(args):
chunk_evaluator = fluid.metrics.ChunkEvaluator()
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
inference_program = fluid.io.get_inference_program(
[num_infer_chunks, num_label_chunks, num_correct_chunks])
inference_program = fluid.default_main_program().clone(for_test=True)
train_reader = paddle.batch(
paddle.reader.shuffle(
......
......@@ -72,7 +72,8 @@ def cnn_model(data):
# TODO(dzhwinter) : refine the initializer and random seed settting
SIZE = 10
input_shape = conv_pool_2.shape
param_shape = [six.moves.reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE]
param_shape = [six.moves.reduce(lambda a, b: a * b, input_shape[1:], 1)
] + [SIZE]
scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5
predict = fluid.layers.fc(
......@@ -90,7 +91,8 @@ def eval_test(exe, batch_acc, batch_size_tensor, inference_program):
paddle.dataset.mnist.test(), batch_size=args.batch_size)
test_pass_acc = fluid.average.WeightedAverage()
for batch_id, data in enumerate(test_reader()):
img_data = np.array([x[0].reshape([1, 28, 28]) for x in data]).astype(DTYPE)
img_data = np.array(
[x[0].reshape([1, 28, 28]) for x in data]).astype(DTYPE)
y_data = np.array([x[1] for x in data]).astype("int64")
y_data = y_data.reshape([len(y_data), 1])
......@@ -123,10 +125,7 @@ def run_benchmark(model, args):
input=predict, label=label, total=batch_size_tensor)
# inference program
inference_program = fluid.default_main_program().clone()
with fluid.program_guard(inference_program):
inference_program = fluid.io.get_inference_program(
target_vars=[batch_acc, batch_size_tensor])
inference_program = fluid.default_main_program().clone(for_test=True)
# Optimization
opt = fluid.optimizer.AdamOptimizer(
......
......@@ -260,9 +260,7 @@ def split_data(data, num_part):
def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
sum_cost, token_num):
# Context to do validation.
test_program = train_progm.clone()
with fluid.program_guard(test_program):
test_program = fluid.io.get_inference_program([avg_cost])
test_program = train_progm.clone(for_test=True)
val_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册