提交 a9b22c5d 编写于 作者: G gongweibao

clean up

上级 c67625e5
......@@ -32,8 +32,6 @@ parser.add_argument(
default=TrainTaskConfig.learning_rate,
help="Learning rate for training.")
parser.add_argument('--num_passes', type=int, default=50, help="No. of passes.")
parser.add_argument(
'--device',
type=str,
......@@ -208,14 +206,16 @@ def main():
return np.mean(test_costs)
def train_loop(exe, trainer_prog):
ts = time.time()
for pass_id in xrange(args.pass_num):
ts = time.time()
total = 0
for batch_id, data in enumerate(train_reader()):
# The current program desc is coupled with batch_size, thus all
# mini-batches must have the same number of instances currently.
if len(data) != args.batch_size:
continue
total += len(data)
start_time = time.time()
data_input = prepare_batch_input(
data, encoder_input_data_names + decoder_input_data_names[:-1] +
......@@ -237,7 +237,7 @@ def main():
# Validate and save the model for inference.
val_cost = test(exe)
print("pass_id = %d cost = %f avg_speed = %.2f sample/s" %
(pass_id, cost_val, len(data) / (time.time() - ts)))
(pass_id, val_cost, total / (time.time() - ts)))
if args.local:
# Initialize the parameters.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册