From 3d884906b6616a03eee7d872a1ccafb0c3199b0f Mon Sep 17 00:00:00 2001 From: ying Date: Fri, 19 Jan 2018 15:21:43 +0800 Subject: [PATCH] format fluid example. --- fluid/image_classification/se_resnext.py | 25 ++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/fluid/image_classification/se_resnext.py b/fluid/image_classification/se_resnext.py index 36e40429..46b938f1 100644 --- a/fluid/image_classification/se_resnext.py +++ b/fluid/image_classification/se_resnext.py @@ -1,4 +1,5 @@ import os + import paddle.v2 as paddle import paddle.v2.fluid as fluid import reader @@ -21,10 +22,12 @@ def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1, def squeeze_excitation(input, num_channels, reduction_ratio): pool = fluid.layers.pool2d( input=input, pool_size=0, pool_type='avg', global_pooling=True) - squeeze = fluid.layers.fc( - input=pool, size=num_channels / reduction_ratio, act='relu') - excitation = fluid.layers.fc( - input=squeeze, size=num_channels, act='sigmoid') + squeeze = fluid.layers.fc(input=pool, + size=num_channels / reduction_ratio, + act='relu') + excitation = fluid.layers.fc(input=squeeze, + size=num_channels, + act='sigmoid') scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) return scale @@ -129,20 +132,18 @@ def train(learning_rate, batch_size, num_passes, model_save_dir='model'): for pass_id in range(num_passes): accuracy.reset(exe) for batch_id, data in enumerate(train_reader()): - loss, acc = exe.run( - fluid.default_main_program(), - feed=feeder.feed(data), - fetch_list=[avg_cost] + accuracy.metrics) + loss, acc = exe.run(fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[avg_cost] + accuracy.metrics) print("Pass {0}, batch {1}, loss {2}, acc {3}".format( pass_id, batch_id, loss[0], acc[0])) pass_acc = accuracy.eval(exe) test_accuracy.reset(exe) for data in test_reader(): - out, acc = exe.run( - inference_program, - feed=feeder.feed(data), - fetch_list=[avg_cost] + test_accuracy.metrics) + out, acc = exe.run(inference_program, + feed=feeder.feed(data), + fetch_list=[avg_cost] + test_accuracy.metrics) test_pass_acc = test_accuracy.eval(exe) print("End pass {0}, train_acc {1}, test_acc {2}".format( pass_id, pass_acc, test_pass_acc)) -- GitLab