提交 3d884906 编写于 作者: Y ying

format fluid example.

上级 eaa61313
import os import os
import paddle.v2 as paddle import paddle.v2 as paddle
import paddle.v2.fluid as fluid import paddle.v2.fluid as fluid
import reader import reader
...@@ -21,10 +22,12 @@ def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1, ...@@ -21,10 +22,12 @@ def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1,
def squeeze_excitation(input, num_channels, reduction_ratio): def squeeze_excitation(input, num_channels, reduction_ratio):
pool = fluid.layers.pool2d( pool = fluid.layers.pool2d(
input=input, pool_size=0, pool_type='avg', global_pooling=True) input=input, pool_size=0, pool_type='avg', global_pooling=True)
squeeze = fluid.layers.fc( squeeze = fluid.layers.fc(input=pool,
input=pool, size=num_channels / reduction_ratio, act='relu') size=num_channels / reduction_ratio,
excitation = fluid.layers.fc( act='relu')
input=squeeze, size=num_channels, act='sigmoid') excitation = fluid.layers.fc(input=squeeze,
size=num_channels,
act='sigmoid')
scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0) scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
return scale return scale
...@@ -129,20 +132,18 @@ def train(learning_rate, batch_size, num_passes, model_save_dir='model'): ...@@ -129,20 +132,18 @@ def train(learning_rate, batch_size, num_passes, model_save_dir='model'):
for pass_id in range(num_passes): for pass_id in range(num_passes):
accuracy.reset(exe) accuracy.reset(exe)
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(train_reader()):
loss, acc = exe.run( loss, acc = exe.run(fluid.default_main_program(),
fluid.default_main_program(), feed=feeder.feed(data),
feed=feeder.feed(data), fetch_list=[avg_cost] + accuracy.metrics)
fetch_list=[avg_cost] + accuracy.metrics)
print("Pass {0}, batch {1}, loss {2}, acc {3}".format( print("Pass {0}, batch {1}, loss {2}, acc {3}".format(
pass_id, batch_id, loss[0], acc[0])) pass_id, batch_id, loss[0], acc[0]))
pass_acc = accuracy.eval(exe) pass_acc = accuracy.eval(exe)
test_accuracy.reset(exe) test_accuracy.reset(exe)
for data in test_reader(): for data in test_reader():
out, acc = exe.run( out, acc = exe.run(inference_program,
inference_program, feed=feeder.feed(data),
feed=feeder.feed(data), fetch_list=[avg_cost] + test_accuracy.metrics)
fetch_list=[avg_cost] + test_accuracy.metrics)
test_pass_acc = test_accuracy.eval(exe) test_pass_acc = test_accuracy.eval(exe)
print("End pass {0}, train_acc {1}, test_acc {2}".format( print("End pass {0}, train_acc {1}, test_acc {2}".format(
pass_id, pass_acc, test_pass_acc)) pass_id, pass_acc, test_pass_acc))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册