diff --git a/python/paddle/fluid/tests/book/image_classification/notest_image_classification_resnet.py b/python/paddle/fluid/tests/book/image_classification/notest_image_classification_resnet.py index 5cbfdef91a64ae7c58d060edfb7b9f3bc8160f2b..17db38797cf19ae387f69f66daa42fc78cfcb7d5 100644 --- a/python/paddle/fluid/tests/book/image_classification/notest_image_classification_resnet.py +++ b/python/paddle/fluid/tests/book/image_classification/notest_image_classification_resnet.py @@ -64,15 +64,14 @@ def resnet_cifar10(input, depth=32): res3 = layer_warp(basicblock, res2, 32, 64, n, 2) pool = fluid.layers.pool2d( input=res3, pool_size=8, pool_type='avg', pool_stride=1) - return pool + predict = fluid.layers.fc(input=pool, size=10, act='softmax') + return predict def inference_network(): - classdim = 10 data_shape = [3, 32, 32] images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') - net = resnet_cifar10(images, 32) - predict = fluid.layers.fc(input=net, size=classdim, act='softmax') + predict = resnet_cifar10(images, 32) return predict diff --git a/python/paddle/fluid/tests/book/image_classification/notest_image_classification_vgg.py b/python/paddle/fluid/tests/book/image_classification/notest_image_classification_vgg.py index 8a6a5ff61a913ad6cbc609f8376afcbc621d60e2..e83afeed2f72635a40aa2ac21dc0c8611c309de4 100644 --- a/python/paddle/fluid/tests/book/image_classification/notest_image_classification_vgg.py +++ b/python/paddle/fluid/tests/book/image_classification/notest_image_classification_vgg.py @@ -43,15 +43,14 @@ def vgg16_bn_drop(input): bn = fluid.layers.batch_norm(input=fc1, act='relu') drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5) fc2 = fluid.layers.fc(input=drop2, size=4096, act=None) - return fc2 + predict = fluid.layers.fc(input=fc2, size=10, act='softmax') + return predict def inference_network(): - classdim = 10 data_shape = [3, 32, 32] images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') - net = vgg16_bn_drop(images) - predict = fluid.layers.fc(input=net, size=classdim, act='softmax') + predict = vgg16_bn_drop(images) return predict