From ed812bdbe87ae3ccd93847c6cb6fbd6fa232bd30 Mon Sep 17 00:00:00 2001 From: kolinwei <331911734@qq.com> Date: Wed, 16 May 2018 15:22:56 +0800 Subject: [PATCH] benchmark script support multi card train --- benchmark/fluid/mnist.py | 12 ++++++------ benchmark/fluid/resnet.py | 10 ++++++---- benchmark/fluid/vgg.py | 10 ++++++---- 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/benchmark/fluid/mnist.py b/benchmark/fluid/mnist.py index 1e2185dfac..3574c1c0e9 100644 --- a/benchmark/fluid/mnist.py +++ b/benchmark/fluid/mnist.py @@ -159,6 +159,7 @@ def run_benchmark(model, args): paddle.dataset.mnist.train(), batch_size=args.batch_size) accuracy = fluid.metrics.Accuracy() + train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name) iters, num_samples, start_time = 0, 0, time.time() for pass_id in range(args.pass_num): accuracy.reset() @@ -175,17 +176,16 @@ def run_benchmark(model, args): y_data = np.array(map(lambda x: x[1], data)).astype("int64") y_data = y_data.reshape([len(y_data), 1]) - outs = exe.run( - fluid.default_main_program(), + outs = train_exe.run( feed={"pixel": img_data, "label": y_data}, - fetch_list=[avg_cost, batch_acc, batch_size_tensor] + fetch_list=[avg_cost.name, batch_acc.name, batch_size_tensor.name] ) # The accuracy is the accumulation of batches, but not the current batch. - accuracy.update(value=outs[1], weight=outs[2]) + accuracy.update(value=np.array(np.mean(outs[1])), weight=np.mean(np.array(outs[2]))) iters += 1 num_samples += len(y_data) - loss = np.array(outs[0]) - acc = np.array(outs[1]) + loss = np.mean(np.array(outs[0])) + acc = np.mean(np.array(outs[1])) train_losses.append(loss) train_accs.append(acc) print("Pass: %d, Iter: %d, Loss: %f, Accuracy: %f" % diff --git a/benchmark/fluid/resnet.py b/benchmark/fluid/resnet.py index 831fa2c019..5f60f806f7 100644 --- a/benchmark/fluid/resnet.py +++ b/benchmark/fluid/resnet.py @@ -241,6 +241,7 @@ def run_benchmark(model, args): exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) accuracy = fluid.average.WeightedAverage() + train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name) if args.use_fake_data: data = train_reader().next() image = np.array(map(lambda x: x[0].reshape(dshape), data)).astype( @@ -264,14 +265,15 @@ def run_benchmark(model, args): data)).astype('float32') label = np.array(map(lambda x: x[1], data)).astype('int64') label = label.reshape([-1, 1]) - loss, acc, weight = exe.run( - fluid.default_main_program(), + loss, acc, weight = train_exe.run( feed={'data': image, 'label': label}, - fetch_list=[avg_cost, batch_acc, batch_size_tensor]) + fetch_list=[avg_cost.name, batch_acc.name, batch_size_tensor.name]) iters += 1 num_samples += len(label) - accuracy.add(value=acc, weight=weight) + accuracy.add(value=np.array(np.mean(acc)), weight=np.mean(weight)) + loss = np.mean(np.array(loss)) + acc = np.mean(np.array(acc)) train_losses.append(loss) train_accs.append(acc) print("Pass: %d, Iter: %d, Loss: %f, Accuracy: %f" % diff --git a/benchmark/fluid/vgg.py b/benchmark/fluid/vgg.py index 53e34e0cbd..261446e7e9 100644 --- a/benchmark/fluid/vgg.py +++ b/benchmark/fluid/vgg.py @@ -169,6 +169,7 @@ def main(): iters, num_samples, start_time = 0, 0, time.time() accuracy = fluid.average.WeightedAverage() + train_exe = fluid.ParallelExecutor(use_cuda=True, loss_name=avg_cost.name) for pass_id in range(args.pass_num): accuracy.reset() train_accs = [] @@ -184,14 +185,15 @@ def main(): y_data = np.array(map(lambda x: x[1], data)).astype("int64") y_data = y_data.reshape([-1, 1]) - loss, acc, weight = exe.run( - fluid.default_main_program(), + loss, acc, weight = train_exe.run( feed={"pixel": img_data, "label": y_data}, - fetch_list=[avg_cost, batch_acc, batch_size_tensor]) - accuracy.add(value=acc, weight=weight) + fetch_list=[avg_cost.name, batch_acc.name, batch_size_tensor.name]) + accuracy.add(value=np.array(np.mean(acc)), weight=np.mean(weight)) iters += 1 num_samples += len(y_data) + loss = np.mean(np.array(loss)) + acc = np.mean(np.array(acc)) print( "Pass = %d, Iter = %d, Loss = %f, Accuracy = %f" % (pass_id, iters, loss, acc) -- GitLab