diff --git a/fluid/icnet/README.md b/fluid/icnet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..56954720bbe672b4a80c9d58150dc958e5e7680c
--- /dev/null
+++ b/fluid/icnet/README.md
@@ -0,0 +1,110 @@
+Running the sample programs in this directory requires the latest PaddlePaddle develop branch. If your installed version of PaddlePaddle is older than this requirement, please update it following the instructions in the [installation documentation](http://www.paddlepaddle.org/docs/develop/documentation/zh/build_and_install/pip_install_cn.html).
+
+
+## Code Structure
+```
+├── icnet.py     # network definition script
+├── train.py     # training script
+├── eval.py      # evaluation script
+├── infer.py     # inference script
+├── cityscape.py # data preprocessing script
+└── utils.py     # common utility functions
+```
+
+## Introduction
+
+Image Cascade Network (ICNet) is designed for real-time semantic segmentation of images. Compared with other approaches that reduce computation, ICNet takes both speed and accuracy into account.
+The main idea of ICNet is to transform the input image into several resolutions and to process each resolution with a sub-network of matching computational complexity, then fuse the results. ICNet consists of three sub-networks: a computationally heavy network processes the low-resolution input, while lightweight networks process the high-resolution inputs. This strikes a balance between the accuracy achievable on high-resolution images and the efficiency of low-complexity networks.
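+
+The cascade, as implemented by `icnet()` in `icnet.py` later in this patch, can be sketched as follows (a simplified excerpt, not the full function):
+```
+def icnet(data, num_classes, input_shape):
+    image_sub1 = data                                       # full resolution
+    image_sub2 = interp(data, out_shape=input_shape * 0.5)  # half resolution
+
+    s_convs = shared_convs(image_sub2)          # trunk shared by coarse branches
+    sub4_out = sub_net_4(s_convs, input_shape)  # heaviest branch, lowest resolution
+    sub2_out = sub_net_2(s_convs)               # medium branch
+    sub1_out = sub_net_1(image_sub1)            # lightest branch, full resolution
+
+    sub24_out = CCF24(sub2_out, sub4_out, input_shape)    # cascade feature fusion
+    sub124_out = CCF124(sub1_out, sub24_out, input_shape)
+    ...
+```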
+
+The overall network structure is shown below:
+
+![ICNet network structure](images/icnet.png)
+
+Figure 1
+
+
+
+## Data Preparation
+
+This example uses the Cityscape dataset. Please register at the [Cityscape website](https://www.cityscapes-dataset.com) and download the data. After downloading, process it with the tools and instructions described [here](https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/createTrainIdLabelImgs.py#L3).
+The processed data is organized as follows:
+```
+data/cityscape/
+|-- gtFine
+| |-- test
+| |-- train
+| `-- val
+|-- leftImg8bit
+| |-- test
+| |-- train
+| `-- val
+|-- train.list
+`-- val.list
+```
+Here, train.list and val.list are the list files used for training and testing respectively. Each line holds an input image in the first column and its annotation in the second column, separated by a space. For example:
+```
+leftImg8bit/train/stuttgart/stuttgart_000021_000019_leftImg8bit.png gtFine/train/stuttgart/stuttgart_000021_000019_gtFine_labelTrainIds.png
+leftImg8bit/train/stuttgart/stuttgart_000072_000019_leftImg8bit.png gtFine/train/stuttgart/stuttgart_000072_000019_gtFine_labelTrainIds.png
+```
+After downloading and preparing the data, update the corresponding data paths in the `cityscape.py` script.
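+
+Specifically, the data paths are defined as module-level constants at the top of `cityscape.py`:
+```
+DATA_PATH = "./data/cityscape"
+TRAIN_LIST = DATA_PATH + "/train.list"
+TEST_LIST = DATA_PATH + "/val.list"
+```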
+
+## Model Training and Prediction
+
+### Training
+Run the following command to start training:
+```
+python train.py --batch_size=16 --use_gpu=True
+```
+Use the following command for more usage information:
+```
+python train.py --help
+```
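+
+The key training hyperparameters are defined as constants near the top of `train.py` (values as in this patch); edit them there to change the schedule:
+```
+LAMBDA1 = 0.16        # weight of the sub4 branch loss
+LAMBDA2 = 0.4         # weight of the sub24 branch loss
+LAMBDA3 = 1.0         # weight of the sub124 branch loss
+LEARNING_RATE = 0.003
+POWER = 0.9           # exponent of the polynomial learning-rate decay
+TOTAL_STEP = 60000
+```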
+During training, the `loss` of each network branch on the training set is printed according to the user's settings, for example:
+```
+Iter[0]; train loss: 2.338; sub4_loss: 3.367; sub24_loss: 4.120; sub124_loss: 0.151
+```
+### Testing
+Run the following command to evaluate on the `Cityscape` test dataset:
+```
+python eval.py --model_path="./model/" --use_gpu=True
+```
+The model file must be specified with the `--model_path` option.
+The evaluation metric reported by the test script is the [mean IoU]().
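+
+Concretely, evaluation accumulates per-class wrong/correct pixel counts with `fluid.layers.mean_iou` and then averages `correct / (wrong + correct)` over the classes that actually appear, as in `cal_mean_iou` in `eval.py`:
+```
+def cal_mean_iou(wrong, correct):
+    total = wrong + correct          # per-class wrong + correct pixel counts
+    true_num = (total != 0).sum()    # number of classes that appear
+    for i in range(len(total)):
+        if total[i] == 0:
+            total[i] = 1             # avoid division by zero for absent classes
+    return (correct.astype("float64") / total).sum() / true_num
+```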
+
+### Prediction
+Run the following command to predict on the specified data:
+```
+python infer.py \
+--model_path="./model" \
+--images_path="./data/cityscape/" \
+--images_list="./data/cityscape/infer.list"
+```
+The `--images_list` option specifies the list file, in which each line is the path of an image to predict.
+By default, prediction results are saved to the `output` folder under the current path.
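+
+For example, an `infer.list` might contain (hypothetical file names, relative to `--images_path`):
+```
+leftImg8bit/val/frankfurt/frankfurt_000000_000294_leftImg8bit.png
+leftImg8bit/val/frankfurt/frankfurt_000001_000538_leftImg8bit.png
+```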
+
+## Experimental Results
+Figure 2 shows the training loss curve on the `CityScape` training set:
+
+![training loss](images/train_loss.png)
+
+Figure 2
+
+
+Training on the training set and validating on the validation set yields mean_IoU = 67.0% (67.7% in the paper).
+
+Figure 3 shows sample results produced by the `infer.py` script: the first row is the input image, the second row is the human annotation, and the third row is the output of our model.
+
+![prediction results](images/result.png)
+
+Figure 3
+
+
+## Other Information
+|Dataset | Pretrained Model |
+|---|---|
+|CityScape | [Model]()[md: ] |
+
+## References
+
+- [ICNet for Real-Time Semantic Segmentation on High-Resolution Images](https://arxiv.org/abs/1704.08545)
diff --git a/fluid/icnet/cityscape.py b/fluid/icnet/cityscape.py
new file mode 100644
index 0000000000000000000000000000000000000000..3288b7f1e178850e6dec99adb2efbb89cbaf8f11
--- /dev/null
+++ b/fluid/icnet/cityscape.py
@@ -0,0 +1,236 @@
+"""Reader for Cityscape dataset.
+"""
+import os
+import cv2
+import numpy as np
+import paddle.v2 as paddle
+
+DATA_PATH = "./data/cityscape"
+TRAIN_LIST = DATA_PATH + "/train.list"
+TEST_LIST = DATA_PATH + "/val.list"
+IGNORE_LABEL = 255
+NUM_CLASSES = 19
+TRAIN_DATA_SHAPE = (3, 720, 720)
+TEST_DATA_SHAPE = (3, 1024, 2048)
+IMG_MEAN = np.array((103.939, 116.779, 123.68), dtype=np.float32)
+
+
+def train_data_shape():
+ return TRAIN_DATA_SHAPE
+
+
+def test_data_shape():
+ return TEST_DATA_SHAPE
+
+
+def num_classes():
+ return NUM_CLASSES
+
+
+class DataGenerator(object):
+ def __init__(self, data_list, mode="train", flip=True, scaling=True):
+ self.flip = flip
+ self.scaling = scaling
+ self.image_label = []
+ with open(data_list, 'r') as f:
+ for line in f:
+ image_file, label_file = line.strip().split(' ')
+ self.image_label.append((image_file, label_file))
+
+ def create_train_reader(self, batch_size):
+ """
+ Create a reader for train dataset.
+ """
+
+ def reader():
+ np.random.shuffle(self.image_label)
+ images = []
+ labels_sub1 = []
+ labels_sub2 = []
+ labels_sub4 = []
+ count = 0
+ for image, label in self.image_label:
+ image, label_sub1, label_sub2, label_sub4 = self.process_train_data(
+ image, label)
+ count += 1
+ images.append(image)
+ labels_sub1.append(label_sub1)
+ labels_sub2.append(label_sub2)
+ labels_sub4.append(label_sub4)
+ if count == batch_size:
+ yield self.mask(
+ np.array(images),
+ np.array(labels_sub1),
+ np.array(labels_sub2), np.array(labels_sub4))
+ images = []
+ labels_sub1 = []
+ labels_sub2 = []
+ labels_sub4 = []
+ count = 0
+ if images:
+ yield self.mask(
+ np.array(images),
+ np.array(labels_sub1),
+ np.array(labels_sub2), np.array(labels_sub4))
+
+ return reader
+
+ def create_test_reader(self):
+ """
+ Create a reader for test dataset.
+ """
+
+ def reader():
+ for image, label in self.image_label:
+ image, label = self.load(image, label)
+ image = paddle.image.to_chw(image)[np.newaxis, :]
+ label = label[np.newaxis, :, :, np.newaxis].astype("float32")
+ label_mask = np.where((label != IGNORE_LABEL).flatten())[
+ 0].astype("int32")
+ yield image, label, label_mask
+
+ return reader
+
+ def process_train_data(self, image, label):
+ """
+ Process training data.
+ """
+ image, label = self.load(image, label)
+ if self.flip:
+ image, label = self.random_flip(image, label)
+ if self.scaling:
+ image, label = self.random_scaling(image, label)
+ image, label = self.resize(image, label, out_size=TRAIN_DATA_SHAPE[1:])
+ label = label.astype("float32")
+ label_sub1 = paddle.image.to_chw(self.scale_label(label, factor=4))
+ label_sub2 = paddle.image.to_chw(self.scale_label(label, factor=8))
+ label_sub4 = paddle.image.to_chw(self.scale_label(label, factor=16))
+ image = paddle.image.to_chw(image)
+ return image, label_sub1, label_sub2, label_sub4
+
+ def load(self, image, label):
+ """
+ Load image from file.
+ """
+ image = paddle.image.load_image(
+ DATA_PATH + "/" + image, is_color=True).astype("float32")
+ image -= IMG_MEAN
+ label = paddle.image.load_image(
+ DATA_PATH + "/" + label, is_color=False).astype("float32")
+ return image, label
+
+ def random_flip(self, image, label):
+ """
+ Flip image and label randomly.
+ """
+ r = np.random.rand(1)
+ if r > 0.5:
+ image = paddle.image.left_right_flip(image, is_color=True)
+ label = paddle.image.left_right_flip(label, is_color=False)
+ return image, label
+
+ def random_scaling(self, image, label):
+ """
+ Scale image and label randomly.
+ """
+ scale = np.random.uniform(0.5, 2.0, 1)[0]
+ h_new = int(image.shape[0] * scale)
+ w_new = int(image.shape[1] * scale)
+ image = cv2.resize(image, (w_new, h_new))
+ label = cv2.resize(
+ label, (w_new, h_new), interpolation=cv2.INTER_NEAREST)
+ return image, label
+
+ def padding_as(self, image, h, w, is_color):
+ """
+ Padding image.
+ """
+ pad_h = max(image.shape[0], h) - image.shape[0]
+ pad_w = max(image.shape[1], w) - image.shape[1]
+ if is_color:
+ return np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), 'constant')
+ else:
+ return np.pad(image, ((0, pad_h), (0, pad_w)), 'constant')
+
+ def resize(self, image, label, out_size):
+ """
+ Resize image and label by padding or cropping.
+ """
+ ignore_label = IGNORE_LABEL
+ label = label - ignore_label
+ if len(label.shape) == 2:
+ label = label[:, :, np.newaxis]
+ combined = np.concatenate((image, label), axis=2)
+ combined = self.padding_as(
+ combined, out_size[0], out_size[1], is_color=True)
+ combined = paddle.image.random_crop(
+ combined, out_size[0], is_color=True)
+ image = combined[:, :, 0:3]
+ label = combined[:, :, 3:4] + ignore_label
+ return image, label
+
+ def scale_label(self, label, factor):
+ """
+ Scale label according to factor.
+ """
+        h = label.shape[0] / factor
+        w = label.shape[1] / factor
+        # cv2.resize expects the target size as (width, height).
+        return cv2.resize(
+            label, (w, h), interpolation=cv2.INTER_NEAREST)[:, :, np.newaxis]
+
+ def mask(self, image, label0, label1, label2):
+ """
+ Get mask for valid pixels.
+ """
+ mask_sub1 = np.where(((label0 < (NUM_CLASSES + 1)) & (
+ label0 != IGNORE_LABEL)).flatten())[0].astype("int32")
+ mask_sub2 = np.where(((label1 < (NUM_CLASSES + 1)) & (
+ label1 != IGNORE_LABEL)).flatten())[0].astype("int32")
+ mask_sub4 = np.where(((label2 < (NUM_CLASSES + 1)) & (
+ label2 != IGNORE_LABEL)).flatten())[0].astype("int32")
+ return image.astype(
+ "float32"), label0, mask_sub1, label1, mask_sub2, label2, mask_sub4
+
+
+def train(batch_size=32, flip=True, scaling=True):
+ """
+ Cityscape training set reader.
+ It returns a reader, in which each result is a batch with batch_size samples.
+
+    :param batch_size: The batch size of each result returned by the reader.
+    :type batch_size: int
+    :param flip: Whether to flip images randomly.
+    :type flip: bool
+    :param scaling: Whether to scale images randomly.
+    :type scaling: bool
+ :return: Training reader.
+ :rtype: callable
+ """
+    reader = DataGenerator(
+ TRAIN_LIST, flip=flip, scaling=scaling).create_train_reader(batch_size)
+ return reader
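+
+# Minimal usage sketch (comment only; assumes the data layout described in README.md):
+#     reader = train(batch_size=16)
+#     for batch in reader():
+#         # batch = (images, label_sub1, mask_sub1, label_sub2, mask_sub2,
+#         #          label_sub4, mask_sub4), ready for utils.get_feeder_data().
+#         pass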
+
+
+def test():
+ """
+ Cityscape validation set reader.
+ It returns a reader, in which each result is a sample.
+
+    :return: Test reader.
+ :rtype: callable
+ """
+    reader = DataGenerator(TEST_LIST).create_test_reader()
+ return reader
+
+
+def infer(image_list=TEST_LIST):
+ """
+ Infer set reader.
+ It returns a reader, in which each result is a sample.
+
+    :param image_list: The image list file in which each line is the path of an image to be inferred.
+    :type image_list: str
+    :return: Infer reader.
+    :rtype: callable
+    """
+    reader = DataGenerator(image_list).create_test_reader()
+    return reader
diff --git a/fluid/icnet/eval.py b/fluid/icnet/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3253c3cb63b8bb58d8a1bdad3318de1c1441142
--- /dev/null
+++ b/fluid/icnet/eval.py
@@ -0,0 +1,96 @@
+"""Evaluator for ICNet model."""
+import paddle.fluid as fluid
+import numpy as np
+from utils import add_arguments, print_arguments, get_feeder_data
+from icnet import icnet
+import cityscape
+import argparse
+import functools
+import sys
+import os
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('model_path', str, None, "Model path.")
+add_arg('use_gpu', bool, True, "Whether use GPU to test.")
+# yapf: enable
+
+
+def cal_mean_iou(wrong, correct):
+    """Average per-class IoU (correct / (wrong + correct)) over classes that appear."""
+    total = wrong + correct
+    true_num = (total != 0).sum()
+    for i in range(len(total)):
+        if total[i] == 0:
+            total[i] = 1
+    return (correct.astype("float64") / total).sum() / true_num
+
+
+def create_iou(predict, label, mask, num_classes, image_shape):
+ predict = fluid.layers.resize_bilinear(predict, out_shape=image_shape[1:3])
+ predict = fluid.layers.transpose(predict, perm=[0, 2, 3, 1])
+ predict = fluid.layers.reshape(predict, shape=[-1, num_classes])
+ label = fluid.layers.reshape(label, shape=[-1, 1])
+ _, predict = fluid.layers.topk(predict, k=1)
+ predict = fluid.layers.cast(predict, dtype="float32")
+ predict = fluid.layers.gather(predict, mask)
+ label = fluid.layers.gather(label, mask)
+ label = fluid.layers.cast(label, dtype="int32")
+ predict = fluid.layers.cast(predict, dtype="int32")
+ iou, out_w, out_r = fluid.layers.mean_iou(predict, label, num_classes)
+ return iou, out_w, out_r
+
+
+def eval(args):
+ data_shape = cityscape.test_data_shape()
+ num_classes = cityscape.num_classes()
+ # define network
+ images = fluid.layers.data(name='image', shape=data_shape, dtype='float32')
+ label = fluid.layers.data(name='label', shape=[1], dtype='int32')
+ mask = fluid.layers.data(name='mask', shape=[-1], dtype='int32')
+
+ _, _, sub124_out = icnet(images, num_classes,
+ np.array(data_shape[1:]).astype("float32"))
+ iou, out_w, out_r = create_iou(sub124_out, label, mask, num_classes,
+ data_shape)
+ inference_program = fluid.default_main_program().clone(for_test=True)
+ # prepare environment
+ place = fluid.CPUPlace()
+ if args.use_gpu:
+ place = fluid.CUDAPlace(0)
+ exe = fluid.Executor(place)
+ exe.run(fluid.default_startup_program())
+ assert os.path.exists(args.model_path)
+ fluid.io.load_params(exe, args.model_path)
+ print "loaded model from: %s" % args.model_path
+ sys.stdout.flush()
+
+ fetch_vars = [iou, out_w, out_r]
+ out_wrong = np.zeros([num_classes]).astype("int64")
+ out_right = np.zeros([num_classes]).astype("int64")
+ count = 0
+ test_reader = cityscape.test()
+ for data in test_reader():
+ count += 1
+ result = exe.run(inference_program,
+ feed=get_feeder_data(
+ data, place, for_test=True),
+ fetch_list=fetch_vars)
+ out_wrong += result[1]
+ out_right += result[2]
+ print "count: %s; current iou: %.3f;\r" % (count, result[0]),
+ sys.stdout.flush()
+ iou = cal_mean_iou(out_wrong, out_right)
+ print "\nmean iou: %.3f" % iou
+
+
+def main():
+ args = parser.parse_args()
+ print_arguments(args)
+ eval(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fluid/icnet/icnet.py b/fluid/icnet/icnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..14eaa5fa25c8570cc8747842333c7ca72f104fd1
--- /dev/null
+++ b/fluid/icnet/icnet.py
@@ -0,0 +1,301 @@
+import paddle.fluid as fluid
+import numpy as np
+import sys
+
+
+def conv(input,
+ k_h,
+ k_w,
+ c_o,
+ s_h,
+ s_w,
+ relu=False,
+ padding="VALID",
+ biased=False,
+ name=None):
+ act = None
+ tmp = input
+ if relu:
+ act = "relu"
+ if padding == "SAME":
+ padding_h = max(k_h - s_h, 0)
+ padding_w = max(k_w - s_w, 0)
+ padding_top = padding_h / 2
+ padding_left = padding_w / 2
+ padding_bottom = padding_h - padding_top
+ padding_right = padding_w - padding_left
+ padding = [
+ 0, 0, 0, 0, padding_top, padding_bottom, padding_left, padding_right
+ ]
+ tmp = fluid.layers.pad(tmp, padding)
+ tmp = fluid.layers.conv2d(
+ tmp,
+ num_filters=c_o,
+ filter_size=[k_h, k_w],
+ stride=[s_h, s_w],
+ groups=1,
+ act=act,
+ bias_attr=biased,
+ use_cudnn=False,
+ name=name)
+ return tmp
+
+
+def atrous_conv(input,
+ k_h,
+ k_w,
+ c_o,
+ dilation,
+ relu=False,
+ padding="VALID",
+ biased=False,
+ name=None):
+ act = None
+ if relu:
+ act = "relu"
+ tmp = input
+ if padding == "SAME":
+ padding_h = max(k_h - s_h, 0)
+ padding_w = max(k_w - s_w, 0)
+ padding_top = padding_h / 2
+ padding_left = padding_w / 2
+ padding_bottom = padding_h - padding_top
+ padding_right = padding_w - padding_left
+ padding = [
+ 0, 0, 0, 0, padding_top, padding_bottom, padding_left, padding_right
+ ]
+ tmp = fluid.layers.pad(tmp, padding)
+
+ tmp = fluid.layers.conv2d(
+ input,
+ num_filters=c_o,
+ filter_size=[k_h, k_w],
+ dilation=dilation,
+ groups=1,
+ act=act,
+ bias_attr=biased,
+ use_cudnn=False,
+ name=name)
+ return tmp
+
+
+def zero_padding(input, padding):
+ return fluid.layers.pad(input,
+ [0, 0, 0, 0, padding, padding, padding, padding])
+
+
+def bn(input, relu=False, name=None, is_test=False):
+    act = None
+    if relu:
+        act = 'relu'
+    if name is None:
+        name = input.name.split(".")[0] + "_bn"
+    tmp = fluid.layers.batch_norm(
+        input, act=act, momentum=0.95, epsilon=1e-5, name=name)
+    return tmp
+
+
+def avg_pool(input, k_h, k_w, s_h, s_w, name=None, padding=0):
+ temp = fluid.layers.pool2d(
+ input,
+ pool_size=[k_h, k_w],
+ pool_type="avg",
+ pool_stride=[s_h, s_w],
+ pool_padding=padding,
+ name=name)
+ return temp
+
+
+def max_pool(input, k_h, k_w, s_h, s_w, name=None, padding=0):
+ temp = fluid.layers.pool2d(
+ input,
+ pool_size=[k_h, k_w],
+ pool_type="max",
+ pool_stride=[s_h, s_w],
+ pool_padding=padding,
+ name=name)
+ return temp
+
+
+def interp(input, out_shape):
+ out_shape = list(out_shape.astype("int32"))
+ return fluid.layers.resize_bilinear(input, out_shape=out_shape)
+
+
+def dilation_convs(input):
+ tmp = res_block(input, filter_num=256, padding=1, name="conv3_2")
+ tmp = res_block(tmp, filter_num=256, padding=1, name="conv3_3")
+ tmp = res_block(tmp, filter_num=256, padding=1, name="conv3_4")
+
+ tmp = proj_block(tmp, filter_num=512, padding=2, dilation=2, name="conv4_1")
+ tmp = res_block(tmp, filter_num=512, padding=2, dilation=2, name="conv4_2")
+ tmp = res_block(tmp, filter_num=512, padding=2, dilation=2, name="conv4_3")
+ tmp = res_block(tmp, filter_num=512, padding=2, dilation=2, name="conv4_4")
+ tmp = res_block(tmp, filter_num=512, padding=2, dilation=2, name="conv4_5")
+ tmp = res_block(tmp, filter_num=512, padding=2, dilation=2, name="conv4_6")
+
+ tmp = proj_block(
+ tmp, filter_num=1024, padding=4, dilation=4, name="conv5_1")
+ tmp = res_block(tmp, filter_num=1024, padding=4, dilation=4, name="conv5_2")
+ tmp = res_block(tmp, filter_num=1024, padding=4, dilation=4, name="conv5_3")
+ return tmp
+
+
+def pyramid_pooling(input, input_shape):
+ shape = np.ceil(input_shape / 32).astype("int32")
+ h, w = shape
+ pool1 = avg_pool(input, h, w, h, w)
+ pool1_interp = interp(pool1, shape)
+ pool2 = avg_pool(input, h / 2, w / 2, h / 2, w / 2)
+ pool2_interp = interp(pool2, shape)
+ pool3 = avg_pool(input, h / 3, w / 3, h / 3, w / 3)
+ pool3_interp = interp(pool3, shape)
+ pool4 = avg_pool(input, h / 4, w / 4, h / 4, w / 4)
+ pool4_interp = interp(pool4, shape)
+ conv5_3_sum = input + pool4_interp + pool3_interp + pool2_interp + pool1_interp
+ return conv5_3_sum
+
+
+def shared_convs(image):
+ tmp = conv(image, 3, 3, 32, 2, 2, padding='SAME', name="conv1_1_3_3_s2")
+ tmp = bn(tmp, relu=True)
+ tmp = conv(tmp, 3, 3, 32, 1, 1, padding='SAME', name="conv1_2_3_3")
+ tmp = bn(tmp, relu=True)
+ tmp = conv(tmp, 3, 3, 64, 1, 1, padding='SAME', name="conv1_3_3_3")
+ tmp = bn(tmp, relu=True)
+ tmp = max_pool(tmp, 3, 3, 2, 2, padding=[1, 1])
+
+ tmp = proj_block(tmp, filter_num=128, padding=0, name="conv2_1")
+ tmp = res_block(tmp, filter_num=128, padding=1, name="conv2_2")
+ tmp = res_block(tmp, filter_num=128, padding=1, name="conv2_3")
+ tmp = proj_block(tmp, filter_num=256, padding=1, stride=2, name="conv3_1")
+ return tmp
+
+
+def res_block(input, filter_num, padding=0, dilation=None, name=None):
+ tmp = conv(input, 1, 1, filter_num / 4, 1, 1, name=name + "_1_1_reduce")
+ tmp = bn(tmp, relu=True)
+ tmp = zero_padding(tmp, padding=padding)
+ if dilation is None:
+ tmp = conv(tmp, 3, 3, filter_num / 4, 1, 1, name=name + "_3_3")
+ else:
+ tmp = atrous_conv(
+ tmp, 3, 3, filter_num / 4, dilation, name=name + "_3_3")
+ tmp = bn(tmp, relu=True)
+ tmp = conv(tmp, 1, 1, filter_num, 1, 1, name=name + "_1_1_increase")
+ tmp = bn(tmp, relu=False)
+ tmp = input + tmp
+ tmp = fluid.layers.relu(tmp, name=name + "_relu")
+ return tmp
+
+
+def proj_block(input, filter_num, padding=0, dilation=None, stride=1,
+ name=None):
+ proj = conv(
+ input, 1, 1, filter_num, stride, stride, name=name + "_1_1_proj")
+ proj_bn = bn(proj, relu=False)
+
+ tmp = conv(
+ input, 1, 1, filter_num / 4, stride, stride, name=name + "_1_1_reduce")
+ tmp = bn(tmp, relu=True)
+
+ tmp = zero_padding(tmp, padding=padding)
+ if padding == 0:
+ padding = 'SAME'
+ else:
+ padding = 'VALID'
+ if dilation is None:
+ tmp = conv(
+ tmp,
+ 3,
+ 3,
+ filter_num / 4,
+ 1,
+ 1,
+ padding=padding,
+ name=name + "_3_3")
+ else:
+ tmp = atrous_conv(
+ tmp,
+ 3,
+ 3,
+ filter_num / 4,
+ dilation,
+ padding=padding,
+ name=name + "_3_3")
+
+ tmp = bn(tmp, relu=True)
+ tmp = conv(tmp, 1, 1, filter_num, 1, 1, name=name + "_1_1_increase")
+ tmp = bn(tmp, relu=False)
+ tmp = proj_bn + tmp
+ tmp = fluid.layers.relu(tmp, name=name + "_relu")
+ return tmp
+
+
+def sub_net_4(input, input_shape):
+ tmp = interp(input, out_shape=np.ceil(input_shape / 32))
+ tmp = dilation_convs(tmp)
+    tmp = pyramid_pooling(tmp, input_shape)
+ tmp = conv(tmp, 1, 1, 256, 1, 1, name="conv5_4_k1")
+ tmp = bn(tmp, relu=True)
+ tmp = interp(tmp, input_shape / 16)
+ return tmp
+
+
+def sub_net_2(input):
+ tmp = conv(input, 1, 1, 128, 1, 1, name="conv3_1_sub2_proj")
+ tmp = bn(tmp, relu=False)
+ return tmp
+
+
+def sub_net_1(input):
+ tmp = conv(input, 3, 3, 32, 2, 2, padding='SAME', name="conv1_sub1")
+ tmp = bn(tmp, relu=True)
+ tmp = conv(tmp, 3, 3, 32, 2, 2, padding='SAME', name="conv2_sub1")
+ tmp = bn(tmp, relu=True)
+ tmp = conv(tmp, 3, 3, 64, 2, 2, padding='SAME', name="conv3_sub1")
+ tmp = bn(tmp, relu=True)
+ tmp = conv(tmp, 1, 1, 128, 1, 1, name="conv3_sub1_proj")
+ tmp = bn(tmp, relu=False)
+ return tmp
+
+
+def CCF24(sub2_out, sub4_out, input_shape):
+ tmp = zero_padding(sub4_out, padding=2)
+ tmp = atrous_conv(tmp, 3, 3, 128, 2, name="conv_sub4")
+ tmp = bn(tmp, relu=False)
+ tmp = tmp + sub2_out
+ tmp = fluid.layers.relu(tmp)
+ tmp = interp(tmp, input_shape / 8)
+ return tmp
+
+
+def CCF124(sub1_out, sub24_out, input_shape):
+ tmp = zero_padding(sub24_out, padding=2)
+ tmp = atrous_conv(tmp, 3, 3, 128, 2, name="conv_sub2")
+ tmp = bn(tmp, relu=False)
+ tmp = tmp + sub1_out
+ tmp = fluid.layers.relu(tmp)
+ tmp = interp(tmp, input_shape / 4)
+ return tmp
+
+
+def icnet(data, num_classes, input_shape):
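+    """Build the ICNet graph; return logits of the sub4, sub24, and sub124 (conv6_cls) branches."""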
+ image_sub1 = data
+ image_sub2 = interp(data, out_shape=input_shape * 0.5)
+
+ s_convs = shared_convs(image_sub2)
+ sub4_out = sub_net_4(s_convs, input_shape)
+ sub2_out = sub_net_2(s_convs)
+ sub1_out = sub_net_1(image_sub1)
+
+ sub24_out = CCF24(sub2_out, sub4_out, input_shape)
+ sub124_out = CCF124(sub1_out, sub24_out, input_shape)
+
+ conv6_cls = conv(
+ sub124_out, 1, 1, num_classes, 1, 1, biased=True, name="conv6_cls")
+ sub4_out = conv(
+ sub4_out, 1, 1, num_classes, 1, 1, biased=True, name="sub4_out")
+ sub24_out = conv(
+ sub24_out, 1, 1, num_classes, 1, 1, biased=True, name="sub24_out")
+
+ return sub4_out, sub24_out, conv6_cls
diff --git a/fluid/icnet/images/icnet.png b/fluid/icnet/images/icnet.png
new file mode 100644
index 0000000000000000000000000000000000000000..f261bb14a85eceac7cd5df282ebc43021b7760d9
Binary files /dev/null and b/fluid/icnet/images/icnet.png differ
diff --git a/fluid/icnet/images/result.png b/fluid/icnet/images/result.png
new file mode 100644
index 0000000000000000000000000000000000000000..b3b0b52ade05943b4a1d741fa4f3a947e8ac28ae
Binary files /dev/null and b/fluid/icnet/images/result.png differ
diff --git a/fluid/icnet/images/train_loss.png b/fluid/icnet/images/train_loss.png
new file mode 100644
index 0000000000000000000000000000000000000000..15011073ae0bd55a9df853934f3329747ee9a426
Binary files /dev/null and b/fluid/icnet/images/train_loss.png differ
diff --git a/fluid/icnet/infer.py b/fluid/icnet/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..63fb3268060248f70462cf914c613c53a1fc1f89
--- /dev/null
+++ b/fluid/icnet/infer.py
@@ -0,0 +1,133 @@
+"""Infer for ICNet model."""
+import cityscape
+import argparse
+import functools
+import sys
+import os
+import cv2
+
+import paddle.fluid as fluid
+import paddle.v2 as paddle
+from icnet import icnet
+from utils import add_arguments, print_arguments, get_feeder_data
+import numpy as np
+
+IMG_MEAN = np.array((103.939, 116.779, 123.68), dtype=np.float32)
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('model_path', str, None, "Model path.")
+add_arg('images_list', str, None, "List file with images to be inferred.")
+add_arg('images_path', str, None, "The images path.")
+add_arg('out_path', str, "./output", "Output path.")
+add_arg('use_gpu', bool, True, "Whether use GPU to test.")
+# yapf: enable
+
+data_shape = [3, 1024, 2048]
+num_classes = 19
+
+label_colours = [
+    [128, 64, 128],   # 0 = road
+    [244, 35, 231],   # 1 = sidewalk
+    [69, 69, 69],     # 2 = building
+    [102, 102, 156],  # 3 = wall
+    [190, 153, 153],  # 4 = fence
+    [153, 153, 153],  # 5 = pole
+    [250, 170, 29],   # 6 = traffic light
+    [219, 219, 0],    # 7 = traffic sign
+    [106, 142, 35],   # 8 = vegetation
+    [152, 250, 152],  # 9 = terrain
+    [69, 129, 180],   # 10 = sky
+    [219, 19, 60],    # 11 = person
+    [255, 0, 0],      # 12 = rider
+    [0, 0, 142],      # 13 = car
+    [0, 0, 69],       # 14 = truck
+    [0, 60, 100],     # 15 = bus
+    [0, 79, 100],     # 16 = train
+    [0, 0, 230],      # 17 = motorcycle
+    [119, 10, 32],    # 18 = bicycle
+]
+
+
+def color(input):
+ """
+ Convert infered result to color image.
+ """
+ result = []
+ for i in input.flatten():
+ result.append(
+ [label_colours[i][2], label_colours[i][1], label_colours[i][0]])
+ result = np.array(result).reshape([input.shape[0], input.shape[1], 3])
+ return result
+
+
+def infer(args):
+ data_shape = cityscape.test_data_shape()
+ num_classes = cityscape.num_classes()
+ # define network
+ images = fluid.layers.data(name='image', shape=data_shape, dtype='float32')
+ _, _, sub124_out = icnet(images, num_classes,
+ np.array(data_shape[1:]).astype("float32"))
+ predict = fluid.layers.resize_bilinear(
+ sub124_out, out_shape=data_shape[1:3])
+ predict = fluid.layers.transpose(predict, perm=[0, 2, 3, 1])
+ predict = fluid.layers.reshape(predict, shape=[-1, num_classes])
+ _, predict = fluid.layers.topk(predict, k=1)
+ predict = fluid.layers.reshape(
+ predict,
+ shape=[data_shape[1], data_shape[2], -1]) # batch_size should be 1
+ inference_program = fluid.default_main_program().clone(for_test=True)
+ # prepare environment
+ place = fluid.CPUPlace()
+ if args.use_gpu:
+ place = fluid.CUDAPlace(0)
+ exe = fluid.Executor(place)
+ exe.run(fluid.default_startup_program())
+ assert os.path.exists(args.model_path)
+ fluid.io.load_params(exe, args.model_path)
+ print "loaded model from: %s" % args.model_path
+ sys.stdout.flush()
+
+ if not os.path.isdir(args.out_path):
+ os.makedirs(args.out_path)
+
+ for line in open(args.images_list):
+ image_file = args.images_path + "/" + line.strip()
+ filename = os.path.basename(image_file)
+ image = paddle.image.load_image(
+ image_file, is_color=True).astype("float32")
+ image -= IMG_MEAN
+ img = paddle.image.to_chw(image)[np.newaxis, :]
+ image_t = fluid.core.LoDTensor()
+ image_t.set(img, place)
+ result = exe.run(inference_program,
+ feed={"image": image_t},
+ fetch_list=[predict])
+ cv2.imwrite(args.out_path + "/" + filename + "_result.png",
+ color(result[0]))
+
+
+def main():
+ args = parser.parse_args()
+ print_arguments(args)
+ infer(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fluid/icnet/train.py b/fluid/icnet/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..298a2113a15614641d573551e67006f9abbe751a
--- /dev/null
+++ b/fluid/icnet/train.py
@@ -0,0 +1,137 @@
+"""Trainer for ICNet model."""
+from icnet import icnet
+import cityscape
+import argparse
+import functools
+import sys
+import time
+import paddle.fluid as fluid
+import numpy as np
+from utils import add_arguments, print_arguments, get_feeder_data
+from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
+from paddle.fluid.initializer import init_on_cpu
+
+parser = argparse.ArgumentParser(description=__doc__)
+add_arg = functools.partial(add_arguments, argparser=parser)
+# yapf: disable
+add_arg('batch_size', int, 16, "Minibatch size.")
+add_arg('checkpoint_path', str, None, "Checkpoint save path.")
+add_arg('init_model', str, None, "Pretrain model path.")
+add_arg('use_gpu', bool, True, "Whether use GPU to train.")
+add_arg('random_mirror', bool, True, "Whether prepare by random mirror.")
+add_arg('random_scaling', bool, True, "Whether prepare by random scaling.")
+# yapf: enable
+
+LAMBDA1 = 0.16
+LAMBDA2 = 0.4
+LAMBDA3 = 1.0
+LEARNING_RATE = 0.003
+POWER = 0.9
+LOG_PERIOD = 1
+CHECKPOINT_PERIOD = 1000
+TOTAL_STEP = 60000
+
+no_grad_set = []
+
+
+def create_loss(predict, label, mask, num_classes):
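+    """Softmax cross-entropy averaged over the valid (non-ignored) pixels selected by mask."""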
+ predict = fluid.layers.transpose(predict, perm=[0, 2, 3, 1])
+ predict = fluid.layers.reshape(predict, shape=[-1, num_classes])
+ label = fluid.layers.reshape(label, shape=[-1, 1])
+ predict = fluid.layers.gather(predict, mask)
+ label = fluid.layers.gather(label, mask)
+ label = fluid.layers.cast(label, dtype="int64")
+ loss = fluid.layers.softmax_with_cross_entropy(predict, label)
+ no_grad_set.append(label.name)
+ return fluid.layers.reduce_mean(loss)
+
+
+def poly_decay():
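+    """Polynomial decay: lr = LEARNING_RATE * (1 - step / TOTAL_STEP) ** POWER."""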
+ global_step = _decay_step_counter()
+ with init_on_cpu():
+ decayed_lr = LEARNING_RATE * (fluid.layers.pow(
+ (1 - global_step / TOTAL_STEP), POWER))
+ return decayed_lr
+
+
+def train(args):
+ data_shape = cityscape.train_data_shape()
+ num_classes = cityscape.num_classes()
+ # define network
+ images = fluid.layers.data(name='image', shape=data_shape, dtype='float32')
+ label_sub1 = fluid.layers.data(name='label_sub1', shape=[1], dtype='int32')
+ label_sub2 = fluid.layers.data(name='label_sub2', shape=[1], dtype='int32')
+ label_sub4 = fluid.layers.data(name='label_sub4', shape=[1], dtype='int32')
+ mask_sub1 = fluid.layers.data(name='mask_sub1', shape=[-1], dtype='int32')
+ mask_sub2 = fluid.layers.data(name='mask_sub2', shape=[-1], dtype='int32')
+ mask_sub4 = fluid.layers.data(name='mask_sub4', shape=[-1], dtype='int32')
+
+ sub4_out, sub24_out, sub124_out = icnet(
+ images, num_classes, np.array(data_shape[1:]).astype("float32"))
+ loss_sub4 = create_loss(sub4_out, label_sub4, mask_sub4, num_classes)
+ loss_sub24 = create_loss(sub24_out, label_sub2, mask_sub2, num_classes)
+ loss_sub124 = create_loss(sub124_out, label_sub1, mask_sub1, num_classes)
+ reduced_loss = LAMBDA1 * loss_sub4 + LAMBDA2 * loss_sub24 + LAMBDA3 * loss_sub124
+
+ regularizer = fluid.regularizer.L2Decay(0.0001)
+ optimizer = fluid.optimizer.Momentum(
+ learning_rate=poly_decay(), momentum=0.9, regularization=regularizer)
+ _, params_grads = optimizer.minimize(reduced_loss, no_grad_set=no_grad_set)
+
+ # prepare environment
+ place = fluid.CPUPlace()
+ if args.use_gpu:
+ place = fluid.CUDAPlace(0)
+ exe = fluid.Executor(place)
+ exe.run(fluid.default_startup_program())
+
+ if args.init_model is not None:
+ print "load model from: %s" % args.init_model
+ sys.stdout.flush()
+ fluid.io.load_params(exe, args.init_model)
+
+ iter_id = 0
+ t_loss = 0.
+ sub4_loss = 0.
+ sub24_loss = 0.
+ sub124_loss = 0.
+ train_reader = cityscape.train(
+ args.batch_size, flip=args.random_mirror, scaling=args.random_scaling)
+ while True:
+ # train a pass
+ for data in train_reader():
+ if iter_id > TOTAL_STEP:
+ return
+ iter_id += 1
+ results = exe.run(
+ feed=get_feeder_data(data, place),
+ fetch_list=[reduced_loss, loss_sub4, loss_sub24, loss_sub124])
+ t_loss += results[0]
+ sub4_loss += results[1]
+ sub24_loss += results[2]
+ sub124_loss += results[3]
+ # training log
+ if iter_id % LOG_PERIOD == 0:
+ print "Iter[%d]; train loss: %.3f; sub4_loss: %.3f; sub24_loss: %.3f; sub124_loss: %.3f" % (
+ iter_id, t_loss / LOG_PERIOD, sub4_loss / LOG_PERIOD,
+ sub24_loss / LOG_PERIOD, sub124_loss / LOG_PERIOD)
+ t_loss = 0.
+ sub4_loss = 0.
+ sub24_loss = 0.
+ sub124_loss = 0.
+ sys.stdout.flush()
+
+            if iter_id % CHECKPOINT_PERIOD == 0 and args.checkpoint_path is not None:
+                dir_name = args.checkpoint_path + "/" + str(iter_id)
+                fluid.io.save_persistables(exe, dirname=dir_name)
+                print "Saved checkpoint: %s" % (dir_name)
+
+
+def main():
+ args = parser.parse_args()
+ print_arguments(args)
+ train(args)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/fluid/icnet/utils.py b/fluid/icnet/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..699841d65f16ffd0dfae0d27e33c2ec52479826e
--- /dev/null
+++ b/fluid/icnet/utils.py
@@ -0,0 +1,114 @@
+"""Contains common utility functions."""
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import distutils.util
+import numpy as np
+from paddle.fluid import core
+
+
+def print_arguments(args):
+ """Print argparse's arguments.
+
+ Usage:
+
+ .. code-block:: python
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("name", default="Jonh", type=str, help="User name.")
+ args = parser.parse_args()
+ print_arguments(args)
+
+ :param args: Input argparse.Namespace for printing.
+ :type args: argparse.Namespace
+ """
+ print("----------- Configuration Arguments -----------")
+ for arg, value in sorted(vars(args).iteritems()):
+ print("%s: %s" % (arg, value))
+ print("------------------------------------------------")
+
+
+def add_arguments(argname, type, default, help, argparser, **kwargs):
+ """Add argparse's argument.
+
+ Usage:
+
+ .. code-block:: python
+
+ parser = argparse.ArgumentParser()
+ add_argument("name", str, "Jonh", "User name.", parser)
+ args = parser.parse_args()
+ """
+ type = distutils.util.strtobool if type == bool else type
+ argparser.add_argument(
+ "--" + argname,
+ default=default,
+ type=type,
+ help=help + ' Default: %(default)s.',
+ **kwargs)
+
+
+def to_lodtensor(data, place):
+ seq_lens = [len(seq) for seq in data]
+ cur_len = 0
+ lod = [cur_len]
+ for l in seq_lens:
+ cur_len += l
+ lod.append(cur_len)
+ flattened_data = np.concatenate(data, axis=0).astype("int32")
+ flattened_data = flattened_data.reshape([len(flattened_data), 1])
+ res = core.LoDTensor()
+ res.set(flattened_data, place)
+ res.set_lod([lod])
+ return res
+
+
+def get_feeder_data(data, place, for_test=False):
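+    """Convert a batch tuple from the cityscape.py readers into a feed dict of LoDTensors."""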
+ feed_dict = {}
+ image_t = core.LoDTensor()
+ image_t.set(data[0], place)
+ feed_dict["image"] = image_t
+
+ if not for_test:
+ labels_sub1_t = core.LoDTensor()
+ labels_sub2_t = core.LoDTensor()
+ labels_sub4_t = core.LoDTensor()
+ mask_sub1_t = core.LoDTensor()
+ mask_sub2_t = core.LoDTensor()
+ mask_sub4_t = core.LoDTensor()
+
+ labels_sub1_t.set(data[1], place)
+ labels_sub2_t.set(data[3], place)
+ mask_sub1_t.set(data[2], place)
+ mask_sub2_t.set(data[4], place)
+ labels_sub4_t.set(data[5], place)
+ mask_sub4_t.set(data[6], place)
+ feed_dict["label_sub1"] = labels_sub1_t
+ feed_dict["label_sub2"] = labels_sub2_t
+ feed_dict["mask_sub1"] = mask_sub1_t
+ feed_dict["mask_sub2"] = mask_sub2_t
+ feed_dict["label_sub4"] = labels_sub4_t
+ feed_dict["mask_sub4"] = mask_sub4_t
+ else:
+ label_t = core.LoDTensor()
+ mask_t = core.LoDTensor()
+ label_t.set(data[1], place)
+ mask_t.set(data[2], place)
+ feed_dict["label"] = label_t
+ feed_dict["mask"] = mask_t
+
+ return feed_dict