def train(learning_rate,
          batch_size,
          num_passes,
          init_model=None,
          model_save_dir='model',
          parallel=True):
    """Train SE_ResNeXt, evaluate it after every pass and export the model.

    Args:
        learning_rate: Learning rate handed to the Momentum optimizer.
        batch_size: Batch size used to batch both the train and test readers.
        num_passes: Number of passes over the training reader.
        init_model: Optional directory passed to
            ``fluid.io.load_persistables`` before training starts. Skipped
            when ``None``.
        model_save_dir: Base directory; after each pass an inference model is
            saved to ``<model_save_dir>/<pass_id>``.
        parallel: When true, the forward network is built inside a
            ``ParallelDo`` block over ``fluid.layers.get_places()``;
            otherwise it is built directly in the main program.
    """
    class_dim = 1000
    image_shape = [3, 224, 224]

    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    if parallel:
        places = fluid.layers.get_places()
        pd = fluid.layers.ParallelDo(places)

        with pd.do():
            image_ = pd.read_input(image)
            label_ = pd.read_input(label)
            out = SE_ResNeXt(input=image_, class_dim=class_dim)
            cost = fluid.layers.cross_entropy(input=out, label=label_)
            avg_cost = fluid.layers.mean(x=cost)
            accuracy = fluid.layers.accuracy(input=out, label=label_)
            pd.write_output(avg_cost)
            pd.write_output(accuracy)

        # pd() returns the values written by the block; they are reduced to a
        # single loss / accuracy value with another mean.
        avg_cost, accuracy = pd()
        avg_cost = fluid.layers.mean(x=avg_cost)
        accuracy = fluid.layers.mean(x=accuracy)
    else:
        out = SE_ResNeXt(input=image, class_dim=class_dim)
        cost = fluid.layers.cross_entropy(input=out, label=label)
        avg_cost = fluid.layers.mean(x=cost)
        accuracy = fluid.layers.accuracy(input=out, label=label)

    optimizer = fluid.optimizer.Momentum(
        learning_rate=learning_rate,
        momentum=0.9,
        regularization=fluid.regularizer.L2Decay(1e-4))
    opts = optimizer.minimize(avg_cost)

    # The evaluation program is derived from a clone of the main program and
    # targets only the loss and accuracy variables.
    inference_program = fluid.default_main_program().clone()
    with fluid.program_guard(inference_program):
        inference_program = fluid.io.get_inference_program(
            [avg_cost, accuracy])

    place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    if init_model is not None:
        fluid.io.load_persistables(exe, init_model)

    train_reader = paddle.batch(reader.train(), batch_size=batch_size)
    test_reader = paddle.batch(reader.test(), batch_size=batch_size)
    feeder = fluid.DataFeeder(place=place, feed_list=[image, label])

    for pass_id in range(num_passes):
        for batch_id, data in enumerate(train_reader()):
            loss = exe.run(fluid.default_main_program(),
                           feed=feeder.feed(data),
                           fetch_list=[avg_cost])
            print("Pass {0}, batch {1}, loss {2}".format(pass_id, batch_id,
                                                         float(loss[0])))

        total_loss = 0.0
        total_acc = 0.0
        total_batch = 0
        for data in test_reader():
            loss, acc = exe.run(inference_program,
                                feed=feeder.feed(data),
                                fetch_list=[avg_cost, accuracy])
            total_loss += float(loss)
            total_acc += float(acc)
            total_batch += 1

        # An empty test reader must not abort training with a
        # ZeroDivisionError; report NaN for the averages instead.
        if total_batch:
            test_loss = total_loss / total_batch
            test_acc = total_acc / total_batch
        else:
            test_loss = float('nan')
            test_acc = float('nan')
        print("End pass {0}, test_loss {1}, test_acc {2}".format(
            pass_id, test_loss, test_acc))

        model_path = os.path.join(model_save_dir, str(pass_id))
        # NOTE(review): with parallel=True, `out` is created inside the
        # pd.do() block -- confirm save_inference_model can export it.
        fluid.io.save_inference_model(model_path, ['image'], [out], exe)


if __name__ == '__main__':
    train(
        learning_rate=0.1,
        batch_size=8,
        num_passes=100,
        init_model=None,
        parallel=False)