Commit a8f0789d authored by luotao1

put clone(for_test=True) before optimization phase

Parent da1c4ed7
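The point of the change: Program.clone(for_test=True) copies only the forward (inference) ops, so it must be called before the optimizer appends backward and update ops to the main program; otherwise get_inference_program had to be used to prune the clone afterwards. A minimal sketch of the required ordering, using a hypothetical tiny network (the names img, label, predict and the learning rate are illustrative, not the benchmark's actual model):

    import paddle.fluid as fluid

    # Hypothetical forward network (placeholder for the benchmark model).
    img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
    predict = fluid.layers.fc(input=img, size=10, act='softmax')
    avg_cost = fluid.layers.mean(
        fluid.layers.cross_entropy(input=predict, label=label))

    # Clone for evaluation BEFORE optimization, so the copy contains only
    # forward ops and needs no extra pruning step.
    test_program = fluid.default_main_program().clone(for_test=True)

    # Only after cloning, let the optimizer append backward/update ops
    # to the default main program.
    fluid.optimizer.Adam(learning_rate=1e-3).minimize(avg_cost)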
@@ -72,7 +72,8 @@ def cnn_model(data):
     # TODO(dzhwinter) : refine the initializer and random seed settting
     SIZE = 10
     input_shape = conv_pool_2.shape
-    param_shape = [six.moves.reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE]
+    param_shape = [six.moves.reduce(lambda a, b: a * b, input_shape[1:], 1)
+                   ] + [SIZE]
     scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5
     predict = fluid.layers.fc(
@@ -90,7 +91,8 @@ def eval_test(exe, batch_acc, batch_size_tensor, inference_program):
         paddle.dataset.mnist.test(), batch_size=args.batch_size)
     test_pass_acc = fluid.average.WeightedAverage()
     for batch_id, data in enumerate(test_reader()):
-        img_data = np.array([x[0].reshape([1, 28, 28]) for x in data]).astype(DTYPE)
+        img_data = np.array(
+            [x[0].reshape([1, 28, 28]) for x in data]).astype(DTYPE)
         y_data = np.array([x[1] for x in data]).astype("int64")
         y_data = y_data.reshape([len(y_data), 1])
@@ -123,10 +125,7 @@ def run_benchmark(model, args):
         input=predict, label=label, total=batch_size_tensor)
     # inference program
-    inference_program = fluid.default_main_program().clone()
-    with fluid.program_guard(inference_program):
-        inference_program = fluid.io.get_inference_program(
-            target_vars=[batch_acc, batch_size_tensor])
+    inference_program = fluid.default_main_program().clone(for_test=True)
     # Optimization
     opt = fluid.optimizer.AdamOptimizer(
@@ -257,13 +257,8 @@ def split_data(data, num_part):
     ]
-def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
+def test_context(test_program, avg_cost, train_exe, dev_count, data_input_names,
                  sum_cost, token_num):
     # Context to do validation.
-    test_program = train_progm.clone()
-    with fluid.program_guard(test_program):
-        test_program = fluid.io.get_inference_program([avg_cost])
     val_data = reader.DataReader(
         src_vocab_fpath=args.src_vocab_fpath,
         trg_vocab_fpath=args.trg_vocab_fpath,
@@ -317,7 +312,7 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
 def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
-               token_num, predict):
+               token_num, predict, test_program):
     # Initialize the parameters.
     if TrainTaskConfig.ckpt_path:
         fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
@@ -362,7 +357,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
             -1] + label_data_input_fields
     if args.val_file_pattern is not None:
-        test = test_context(train_progm, avg_cost, train_exe, dev_count,
+        test = test_context(test_program, avg_cost, train_exe, dev_count,
                             data_input_names, sum_cost, token_num)
     # the best cross-entropy value with label smoothing
@@ -486,6 +481,7 @@ def train(args):
         TrainTaskConfig.warmup_steps,
         TrainTaskConfig.learning_rate)
+    test_program = fluid.default_main_program().clone(for_test=True)
     if args.local:
         optimizer = fluid.optimizer.Adam(
             learning_rate=lr_scheduler.learning_rate,
@@ -513,7 +509,7 @@ def train(args):
         print("local start_up:")
         train_loop(exe,
                    fluid.default_main_program(), dev_count, sum_cost, avg_cost,
-                   lr_scheduler, token_num, predict)
+                   lr_scheduler, token_num, predict, test_program)
     else:
         port = os.getenv("PADDLE_PORT", "6174")
         pserver_ips = os.getenv("PADDLE_PSERVERS")  # ip,ip...
@@ -550,7 +546,7 @@ def train(args):
             with open('trainer_prog.desc', 'w') as f:
                 f.write(str(trainer_prog))
             train_loop(exe, trainer_prog, dev_count, sum_cost, avg_cost,
-                       lr_scheduler, token_num, predict)
+                       lr_scheduler, token_num, predict, test_program)
         else:
             print("environment var TRAINER_ROLE should be TRAINER os PSERVER")