diff --git a/03.image_classification/resnet.py b/03.image_classification/resnet.py index c60d19fc59dfea31d8a9b22d974047f60475b092..ab475907db0a6187d3a87cb4f5b3604652f7c8b4 100644 --- a/03.image_classification/resnet.py +++ b/03.image_classification/resnet.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle.v2 as paddle +import paddle.fluid as fluid __all__ = ['resnet_cifar10'] @@ -22,37 +22,35 @@ def conv_bn_layer(input, filter_size, stride, padding, - active_type=paddle.activation.Relu(), - ch_in=None): - tmp = paddle.layer.img_conv( + act='relu', + bias_attr=False): + tmp = fluid.layers.conv2d( input=input, filter_size=filter_size, - num_channels=ch_in, num_filters=ch_out, stride=stride, padding=padding, - act=paddle.activation.Linear(), - bias_attr=False) - return paddle.layer.batch_norm(input=tmp, act=active_type) + act=None, + bias_attr=bias_attr) + return fluid.layers.batch_norm(input=tmp, act=act) -def shortcut(ipt, ch_in, ch_out, stride): +def shortcut(input, ch_in, ch_out, stride): if ch_in != ch_out: - return conv_bn_layer(ipt, ch_out, 1, stride, 0, - paddle.activation.Linear()) + return conv_bn_layer(input, ch_out, 1, stride, 0, None) else: - return ipt + return input -def basicblock(ipt, ch_in, ch_out, stride): - tmp = conv_bn_layer(ipt, ch_out, 3, stride, 1) - tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, paddle.activation.Linear()) - short = shortcut(ipt, ch_in, ch_out, stride) - return paddle.layer.addto(input=[tmp, short], act=paddle.activation.Relu()) +def basicblock(input, ch_in, ch_out, stride): + tmp = conv_bn_layer(input, ch_out, 3, stride, 1) + tmp = conv_bn_layer(tmp, ch_out, 3, 1, 1, act=None, bias_attr=True) + short = shortcut(input, ch_in, ch_out, stride) + return fluid.layers.elementwise_add(x=tmp, y=short, act='relu') -def layer_warp(block_func, ipt, ch_in, ch_out, count, stride): - tmp = block_func(ipt, ch_in, ch_out, stride) +def layer_warp(block_func, input, ch_in, ch_out, count, stride): + tmp = block_func(input, ch_in, ch_out, stride) for i in range(1, count): tmp = block_func(tmp, ch_out, ch_out, 1) return tmp @@ -63,11 +61,11 @@ def resnet_cifar10(ipt, depth=32): assert (depth - 2) % 6 == 0 n = (depth - 2) / 6 nStages = {16, 64, 128} - conv1 = conv_bn_layer( - ipt, ch_in=3, ch_out=16, filter_size=3, stride=1, padding=1) + conv1 = conv_bn_layer(ipt, ch_out=16, filter_size=3, stride=1, padding=1) res1 = layer_warp(basicblock, conv1, 16, 16, n, 1) res2 = layer_warp(basicblock, res1, 16, 32, n, 2) res3 = layer_warp(basicblock, res2, 32, 64, n, 2) - pool = paddle.layer.img_pool( - input=res3, pool_size=8, stride=1, pool_type=paddle.pooling.Avg()) - return pool + pool = fluid.layers.pool2d( + input=res3, pool_size=8, pool_type='avg', pool_stride=1) + predict = fluid.layers.fc(input=pool, size=10, act='softmax') + return predict diff --git a/03.image_classification/train.py b/03.image_classification/train.py index faafc7ff5038cd8b40944d7742a4d1612468f80b..8c0a2ceed1836a1fff471a9a87f73a9370abd742 100644 --- a/03.image_classification/train.py +++ b/03.image_classification/train.py @@ -12,92 +12,84 @@ # See the License for the specific language governing permissions and # limitations under the License -import sys, os +from __future__ import print_function -import paddle.v2 as paddle +import paddle +import paddle.fluid as fluid +import numpy +import sys from vgg import vgg_bn_drop from resnet import resnet_cifar10 -with_gpu = os.getenv('WITH_GPU', '0') != '0' +def inference_network(): + # The image is 32 * 32 with RGB representation. + data_shape = [3, 32, 32] + images = fluid.layers.data(name='pixel', shape=data_shape, dtype='float32') -def main(): - datadim = 3 * 32 * 32 - classdim = 10 + predict = resnet_cifar10(images, 32) + # predict = vgg_bn_drop(images) # un-comment to use vgg net + return predict - # PaddlePaddle init - paddle.init(use_gpu=with_gpu, trainer_count=1) - image = paddle.layer.data( - name="image", type=paddle.data_type.dense_vector(datadim)) +def train_network(): + predict = inference_network() + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + cost = fluid.layers.cross_entropy(input=predict, label=label) + avg_cost = fluid.layers.mean(cost) + accuracy = fluid.layers.accuracy(input=predict, label=label) + return [avg_cost, accuracy] - # Add neural network config - # option 1. resnet - # net = resnet_cifar10(image, depth=32) - # option 2. vgg - net = vgg_bn_drop(image) - out = paddle.layer.fc( - input=net, size=classdim, act=paddle.activation.Softmax()) +def train(use_cuda, train_program, params_dirname): + BATCH_SIZE = 128 + EPOCH_NUM = 2 - lbl = paddle.layer.data( - name="label", type=paddle.data_type.integer_value(classdim)) - cost = paddle.layer.classification_cost(input=out, label=lbl) + train_reader = paddle.batch( + paddle.reader.shuffle(paddle.dataset.cifar.train10(), buf_size=50000), + batch_size=BATCH_SIZE) - # Create parameters - parameters = paddle.parameters.create(cost) + test_reader = paddle.batch( + paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE) - # Create optimizer - momentum_optimizer = paddle.optimizer.Momentum( - momentum=0.9, - regularization=paddle.optimizer.L2Regularization(rate=0.0002 * 128), - learning_rate=0.1 / 128.0, - learning_rate_decay_a=0.1, - learning_rate_decay_b=50000 * 100, - learning_rate_schedule='discexp') - - # Create trainer - trainer = paddle.trainer.SGD( - cost=cost, parameters=parameters, update_equation=momentum_optimizer) - - # End batch and end pass event handler def event_handler(event): - if isinstance(event, paddle.event.EndIteration): - if event.batch_id % 100 == 0: - print "\nPass %d, Batch %d, Cost %f, %s" % ( - event.pass_id, event.batch_id, event.cost, event.metrics) + if isinstance(event, fluid.EndStepEvent): + if event.step % 100 == 0: + print("Pass %d, Batch %d, Cost %f, Acc %f" % + (event.step, event.epoch, event.metrics[0], + event.metrics[1])) else: sys.stdout.write('.') sys.stdout.flush() - if isinstance(event, paddle.event.EndPass): - # save parameters - with open('params_pass_%d.tar' % event.pass_id, 'w') as f: - trainer.save_parameter_to_tar(f) - - result = trainer.test( - reader=paddle.batch( - paddle.dataset.cifar.test10(), batch_size=128), - feeding={'image': 0, - 'label': 1}) - print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics) - - # Save the inference topology to protobuf. - inference_topology = paddle.topology.Topology(layers=out) - with open("inference_topology.pkl", 'wb') as f: - inference_topology.serialize_for_inference(f) + + if isinstance(event, fluid.EndEpochEvent): + avg_cost, accuracy = trainer.test( + reader=test_reader, feed_order=['pixel', 'label']) + + print('Loss {0:2.2}, Acc {1:2.2}'.format(avg_cost, accuracy)) + if params_dirname is not None: + trainer.save_params(params_dirname) + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + trainer = fluid.Trainer( + train_func=train_program, + optimizer=fluid.optimizer.Adam(learning_rate=0.001), + place=place) trainer.train( - reader=paddle.batch( - paddle.reader.shuffle( - paddle.dataset.cifar.train10(), buf_size=50000), - batch_size=128), - num_passes=200, + reader=train_reader, + num_epochs=EPOCH_NUM, event_handler=event_handler, - feeding={'image': 0, - 'label': 1}) + feed_order=['pixel', 'label']) - # inference + +def infer(use_cuda, inference_program, params_dirname=None): + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + inferencer = fluid.Inferencer( + infer_func=inference_program, param_path=params_dirname, place=place) + + # Prepare testing data. from PIL import Image import numpy as np import os @@ -105,6 +97,7 @@ def main(): def load_image(file): im = Image.open(file) im = im.resize((32, 32), Image.ANTIALIAS) + im = np.array(im).astype(np.float32) # The storage order of the loaded image is W(widht), # H(height), C(channel). PaddlePaddle requires @@ -114,23 +107,38 @@ def main(): # image is B(Blue), G(green), R(Red). But PIL open # image in RGB mode. It must swap the channel order. im = im[(2, 1, 0), :, :] # BGR - im = im.flatten() im = im / 255.0 + + # Add one dimension to mimic the list format. + im = numpy.expand_dims(im, axis=0) return im - test_data = [] cur_dir = os.path.dirname(os.path.realpath(__file__)) - test_data.append((load_image(cur_dir + '/image/dog.png'), )) + img = load_image(cur_dir + '/image/dog.png') + + # inference + results = inferencer.infer({'pixel': img}) + + print("infer results: ", results) + + +def main(use_cuda): + if use_cuda and not fluid.core.is_compiled_with_cuda(): + return + save_path = "image_classification_resnet.inference.model" - # users can remove the comments and change the model name - # with open('params_pass_50.tar', 'r') as f: - # parameters = paddle.parameters.Parameters.from_tar(f) + train( + use_cuda=use_cuda, + train_program=train_network, + params_dirname=save_path) - probs = paddle.infer( - output_layer=out, parameters=parameters, input=test_data) - lab = np.argsort(-probs) # probs and lab are the results of one batch data - print "Label of image/dog.png is: %d" % lab[0][0] + infer( + use_cuda=use_cuda, + inference_program=inference_network, + params_dirname=save_path) if __name__ == '__main__': - main() + # For demo purpose, the training runs on CPU + # Please change accordingly. + main(use_cuda=False) diff --git a/03.image_classification/vgg.py b/03.image_classification/vgg.py index 1e0e6b93adde30425f17aa9cd07542275f4fec37..9f0f697bace3e432fed03f1df0aa142aab802e30 100644 --- a/03.image_classification/vgg.py +++ b/03.image_classification/vgg.py @@ -12,36 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle.v2 as paddle +import paddle +import paddle.fluid as fluid __all__ = ['vgg_bn_drop'] def vgg_bn_drop(input): - def conv_block(ipt, num_filter, groups, dropouts, num_channels=None): - return paddle.networks.img_conv_group( + def conv_block(ipt, num_filter, groups, dropouts): + return fluid.nets.img_conv_group( input=ipt, - num_channels=num_channels, pool_size=2, pool_stride=2, conv_num_filter=[num_filter] * groups, conv_filter_size=3, - conv_act=paddle.activation.Relu(), + conv_act='relu', conv_with_batchnorm=True, conv_batchnorm_drop_rate=dropouts, - pool_type=paddle.pooling.Max()) + pool_type='max') - conv1 = conv_block(input, 64, 2, [0.3, 0], 3) + conv1 = conv_block(input, 64, 2, [0.3, 0]) conv2 = conv_block(conv1, 128, 2, [0.4, 0]) conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0]) conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0]) conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0]) - drop = paddle.layer.dropout(input=conv5, dropout_rate=0.5) - fc1 = paddle.layer.fc(input=drop, size=512, act=paddle.activation.Linear()) - bn = paddle.layer.batch_norm( - input=fc1, - act=paddle.activation.Relu(), - layer_attr=paddle.attr.Extra(drop_rate=0.5)) - fc2 = paddle.layer.fc(input=bn, size=512, act=paddle.activation.Linear()) - return fc2 + drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5) + fc1 = fluid.layers.fc(input=drop, size=512, act=None) + bn = fluid.layers.batch_norm(input=fc1, act='relu') + drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5) + fc2 = fluid.layers.fc(input=drop2, size=512, act=None) + predict = fluid.layers.fc(input=fc2, size=10, act='softmax') + return predict \ No newline at end of file