@@ -21,6 +23,9 @@ PaddleServing在框架中具有一些预定义的计算节点。 一种非常常
 
 ``` python
 import paddle_serving_server as serving
+from paddle_serving_server import OpMaker
+from paddle_serving_server import OpSeqMaker
+
 op_maker = serving.OpMaker()
 read_op = op_maker.create('general_reader')
 general_infer_op = op_maker.create('general_infer')
@@ -32,18 +37,54 @@ op_seq_maker.add_op(general_infer_op)
 op_seq_maker.add_op(general_response_op)
 ```
 
+对于简单的串联逻辑,我们将其简化为`Sequence`,使用`OpSeqMaker`进行构建。用户可以不指定每个节点的前继,默认按加入`OpSeqMaker`的顺序来确定前继。
+
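+构建好序列之后,还需要把它配置到 `Server` 中并启动服务。下面是一个示意性的片段(假设沿用上文示例中的 `uci_housing_model` 作为模型路径,所用接口与本次改动中的 `Server` 一致):
+
+``` python
+from paddle_serving_server import Server
+
+server = Server()
+# 将上面用 OpSeqMaker 构建好的 op 序列设置到 server 中
+server.set_op_sequence(op_seq_maker.get_op_sequence())
+# 单模型场景下,直接传入模型路径字符串即可
+server.load_model_config('uci_housing_model')
+server.prepare_server(workdir='work_dir1', port=9292, device='cpu')
+server.run_server()
+```
+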
 由于该代码在大多数情况下都会被使用,并且用户不必更改代码,因此PaddleServing会发布一个易于使用的启动命令来启动服务。 示例如下:
 
 ``` shell
 python -m paddle_serving_server.serve --model uci_housing_model --thread 10 --port 9292
 ```
 
+### 包含多个输入的节点
+
+在[Paddle Serving中的集成预测](MODEL_ENSEMBLE_IN_PADDLE_SERVING_CN.md)文档中给出了一个包含多个输入节点的样例,示意图和代码如下。
+
+![包含多个输入节点的复杂 DAG 示意图](complex_dag.png)
+
+```python
+from paddle_serving_server import OpMaker
+from paddle_serving_server import OpGraphMaker
+from paddle_serving_server import Server
+
+op_maker = OpMaker()
+read_op = op_maker.create('general_reader')
+cnn_infer_op = op_maker.create(
+    'general_infer', engine_name='cnn', inputs=[read_op])
+bow_infer_op = op_maker.create(
+    'general_infer', engine_name='bow', inputs=[read_op])
+response_op = op_maker.create(
+    'general_response', inputs=[cnn_infer_op, bow_infer_op])
+
+op_graph_maker = OpGraphMaker()
+op_graph_maker.add_op(read_op)
+op_graph_maker.add_op(cnn_infer_op)
+op_graph_maker.add_op(bow_infer_op)
+op_graph_maker.add_op(response_op)
+```
+
+对于含有多输入节点的计算图,需要使用`OpGraphMaker`来构建,同时必须给出每个节点的前继。
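+
+下面是一个示意性的片段(与本次改动中的 `python/examples/imdb/test_ensemble_server.py` 写法一致),展示如何将上面构建的计算图配置到 `Server` 中,并以 `{op: model_path}` 字典的形式为每个预测节点指定模型路径:
+
+```python
+from paddle_serving_server import Server
+
+server = Server()
+server.set_op_graph(op_graph_maker.get_op_graph())
+# 多模型场景下,使用 {op: model_path} 字典为每个预测 op 指定模型
+model_config = {cnn_infer_op: 'imdb_cnn_model', bow_infer_op: 'imdb_bow_model'}
+server.load_model_config(model_config)
+server.prepare_server(workdir="work_dir1", port=9393, device="cpu")
+server.run_server()
+```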
+
 ## 更多示例
 
 如果用户将稀疏特征作为输入,并且模型将对每个特征进行嵌入查找,则我们可以进行分布式嵌入查找操作,该操作不在Paddle训练计算图中。 示例如下:
 
 ``` python
 import paddle_serving_server as serving
+from paddle_serving_server import OpMaker
+from paddle_serving_server import OpSeqMaker
+
 op_maker = serving.OpMaker()
 read_op = op_maker.create('general_reader')
 dist_kv_op = op_maker.create('general_dist_kv')
diff --git a/doc/complex_dag.png b/doc/complex_dag.png
new file mode 100644
index 0000000000000000000000000000000000000000..4e844d9fc3915579ec44bb981e9e2bfc3e4f7675
Binary files /dev/null and b/doc/complex_dag.png differ
diff --git a/doc/model_ensemble_example.png b/doc/model_ensemble_example.png
new file mode 100644
index 0000000000000000000000000000000000000000..823e91ee9ea6e2b10c3bd2c0ca119f088582c685
Binary files /dev/null and b/doc/model_ensemble_example.png differ
diff --git a/python/examples/faster_rcnn_model/000000570688.jpg b/python/examples/faster_rcnn_model/000000570688.jpg
new file mode 100755
index 0000000000000000000000000000000000000000..cb304bd56c4010c08611a30dcca58ea9140cea54
Binary files /dev/null and b/python/examples/faster_rcnn_model/000000570688.jpg differ
diff --git a/python/examples/faster_rcnn_model/000000570688_bbox.jpg b/python/examples/faster_rcnn_model/000000570688_bbox.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..61bc11c02c92b92cffac91a6c3533a90a45c4e14
Binary files /dev/null and b/python/examples/faster_rcnn_model/000000570688_bbox.jpg differ
diff --git a/python/examples/faster_rcnn_model/README.md b/python/examples/faster_rcnn_model/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..aca6e5183daf9a46096587fab8276b4e7346f746
--- /dev/null
+++ b/python/examples/faster_rcnn_model/README.md
@@ -0,0 +1,70 @@
+# Faster RCNN model on Paddle Serving
+
+([简体中文](./README_CN.md)|English)
+
+This article requires models and configuration files trained with [Paddle Detection](https://github.com/PaddlePaddle/PaddleDetection). If you want to quickly deploy on Paddle Serving, please read Chapter 2 directly.
+
+## 1. Train an object detection model
+
+Users can read [Paddle Detection Getting Started](https://github.com/PaddlePaddle/PaddleDetection/blob/release/0.2/docs/tutorials/GETTING_STARTED_cn.md) to understand the background of Paddle Detection. The purpose of PaddleDetection is to provide rich and easy-to-use object detection models for industry and academia. It is not only superior in performance and easy to deploy, but can also flexibly meet the needs of algorithm research.
+
+### Environmental requirements
+
+CPU version: No special requirements
+
+GPU version: CUDA 9.0 and above
+
+```
+git clone https://github.com/PaddlePaddle/PaddleDetection
+cd PaddleDetection
+```
+Next, you can train the Faster RCNN model
+```
+python tools/train.py -c configs/faster_rcnn_r50_1x.yml
+```
+The training time depends on the computing power of the training device and the number of iterations.
+During training, `faster_rcnn_r50_1x.yml` defines the `snapshot` used to save the model. After training finishes, the model with the best effect is saved as `best_model.pdmodel`, which is a compressed model file specific to PaddleDetection.
+
+**If we want the model to be used by Paddle Serving, we must export it with export_model.**
+
+Export the model
+```
+python export_model.py
+```
+## 2. Start the model and predict
+If you did not train a model with the Paddle Detection project, we also provide a sample model for download here. If you trained the model with Paddle Detection, you can skip the **Download model** section.
+
+### Download model
+```
+wget https://paddle-serving.bj.bcebos.com/pddet_demo/faster_rcnn_model.tar.gz
+wget https://paddle-serving.bj.bcebos.com/pddet_demo/paddle_serving_app-0.0.1-py2-none-any.whl
+wget https://paddle-serving.bj.bcebos.com/pddet_demo/infer_cfg.yml
+tar xf faster_rcnn_model.tar.gz
+mv faster_rcnn_model/pddet* .
+```
+
+### Start the service
+```
+GLOG_v=2 python -m paddle_serving_server_gpu.serve --model pddet_serving_model --port 9494 --gpu_id 0
+```
+
+### Perform prediction
+```
+python test_client.py --config_path=infer_cfg.yml --infer_img=000000570688.jpg --dump_result --visualize
+```
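+
+For reference, the prediction client `test_client.py` in this directory roughly performs the steps sketched below. The client config path is a placeholder here (use the `serving_client_conf.prototxt` shipped with the downloaded model); the other names come from the script itself:
+
+```python
+from paddle_serving_client import Client
+from paddle_serving_app.reader.pddet import Detection
+
+# preprocess the image according to the PaddleDetection config file
+pddet = Detection(config_path='infer_cfg.yml', output_dir='./output')
+feed_dict = pddet.preprocess(['image', 'im_shape', 'im_info'], '000000570688.jpg')
+
+# connect to the serving instance started above and run inference
+client = Client()
+client.load_client_config('serving_client_conf.prototxt')  # placeholder path
+client.connect(['127.0.0.1:9494'])
+fetch_map = client.predict(feed=feed_dict, fetch=['multiclass_nms'])
+
+# draw the bboxes and dump output/bbox.json
+pddet.postprocess(fetch_map, ['multiclass_nms'])
+```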
+
+## 3. Result analysis
+
+![Input image](000000570688.jpg)
+
+This is the input picture.
+
+![Result image with bounding boxes](000000570688_bbox.jpg)
+
+This is the picture after the bounding boxes have been added. You can see that the client has done the post-processing for the picture. In addition, output/bbox.json also contains the index and coordinate information of each box.
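+
+Each entry in `output/bbox.json` follows the structure produced by `Detection.bbox2out` in `paddle_serving_app.reader.pddet` (a COCO category id, an `[xmin, ymin, w, h]` box and a confidence score); the values below are purely illustrative:
+
+```python
+# one illustrative entry of output/bbox.json
+{"category_id": 1, "bbox": [217.2, 81.4, 130.8, 320.6], "score": 0.97}
+```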
diff --git a/python/examples/faster_rcnn_model/README_CN.md b/python/examples/faster_rcnn_model/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..a1ac36ff93f5d75a4d8874b89f3cb1509589c4d0
--- /dev/null
+++ b/python/examples/faster_rcnn_model/README_CN.md
@@ -0,0 +1,70 @@
+# Faster RCNN模型
+
+(简体中文|[English](./README.md))
+
+本文需要[Paddle Detection](https://github.com/PaddlePaddle/PaddleDetection)训练的模型和配置文件。如果用户想要快速部署在Paddle Serving上,请直接阅读第二章节。
+
+## 1. 训练物体检测模型
+
+用户可以阅读 [Paddle Detection入门使用](https://github.com/PaddlePaddle/PaddleDetection/blob/release/0.2/docs/tutorials/GETTING_STARTED_cn.md)来了解Paddle Detection的背景。PaddleDetection的目的是为工业界和学术界提供丰富、易用的目标检测模型。不仅性能优越、易于部署,而且能够灵活的满足算法研究的需求。
+
+### 环境要求
+
+CPU版: 没有特别要求
+
+GPU版: CUDA 9.0及以上
+
+```
+git clone https://github.com/PaddlePaddle/PaddleDetection
+cd PaddleDetection
+```
+接下来可以训练faster rcnn模型
+```
+python tools/train.py -c configs/faster_rcnn_r50_1x.yml
+```
+训练模型的时间视情况而定,与训练的设备算力和迭代轮数相关。
+在训练的过程中,`faster_rcnn_r50_1x.yml`当中定义了保存模型的`snapshot`,在最终训练完成后,效果最好的模型,会被保存为`best_model.pdmodel`,这是一个经过压缩的PaddleDetection的专属模型文件。
+
+**如果我们要让模型可被Paddle Serving所使用,必须做export_model。**
+
+输出模型
+```
+python export_model.py
+```
+
+## 2. 启动模型并预测
+如果用户没有用Paddle Detection项目训练模型,我们也在此为您提供示例模型下载。如果您用Paddle Detection训练了模型,可以跳过 **下载模型** 部分。
+
+### 下载模型
+```
+wget https://paddle-serving.bj.bcebos.com/pddet_demo/faster_rcnn_model.tar.gz
+wget https://paddle-serving.bj.bcebos.com/pddet_demo/paddle_serving_app-0.0.1-py2-none-any.whl
+wget https://paddle-serving.bj.bcebos.com/pddet_demo/infer_cfg.yml
+tar xf faster_rcnn_model.tar.gz
+mv faster_rcnn_model/pddet* .
+```
+
+### 启动服务
+```
+GLOG_v=2 python -m paddle_serving_server_gpu.serve --model pddet_serving_model --port 9494 --gpu_id 0
+```
+
+### 执行预测
+```
+python test_client.py --config_path=infer_cfg.yml --infer_img=000000570688.jpg --dump_result --visualize
+```
+
+## 3. 结果分析
+
+![输入图片](000000570688.jpg)
+
+这是输入图片。
+
+![添加bbox后的图片](000000570688_bbox.jpg)
+
+这是添加了bbox之后的图片,可以看到客户端已经为图片做好了后处理,此外在output/bbox.json中也有各个框的编号和坐标信息。
diff --git a/python/examples/faster_rcnn_model/test_client.py b/python/examples/faster_rcnn_model/test_client.py
new file mode 100755
index 0000000000000000000000000000000000000000..ae2e5b8f6e961d965555d8f268f38be14c0263d0
--- /dev/null
+++ b/python/examples/faster_rcnn_model/test_client.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle_serving_client import Client
+import sys
+import os
+import time
+from paddle_serving_app.reader.pddet import Detection
+import numpy as np
+
+py_version = sys.version_info[0]
+
+feed_var_names = ['image', 'im_shape', 'im_info']
+fetch_var_names = ['multiclass_nms']
+pddet = Detection(config_path=sys.argv[2], output_dir="./output")
+feed_dict = pddet.preprocess(feed_var_names, sys.argv[3])
+client = Client()
+client.load_client_config(sys.argv[1])
+client.connect(['127.0.0.1:9494'])
+fetch_map = client.predict(feed=feed_dict, fetch=fetch_var_names)
+outs = fetch_map.values()
+pddet.postprocess(fetch_map, fetch_var_names)
diff --git a/python/examples/imdb/test_ensemble_client.py b/python/examples/imdb/test_ensemble_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cafb3389fff5a25103bcb2b3a867b73b35b9e8e
--- /dev/null
+++ b/python/examples/imdb/test_ensemble_client.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# pylint: disable=doc-string-missing
+
+from paddle_serving_client import Client
+from imdb_reader import IMDBDataset
+
+client = Client()
+# If you have more than one model, make sure that the inputs
+# and outputs of all the models are the same.
+client.load_client_config('imdb_bow_client_conf/serving_client_conf.prototxt')
+client.connect(["127.0.0.1:9393"])
+
+# You can use any English sentence or dataset here.
+# This example reuses the imdb reader from training; you
+# can easily define your own data preprocessing instead.
+imdb_dataset = IMDBDataset()
+imdb_dataset.load_resource('imdb.vocab')
+
+for i in range(3):
+    line = 'i am very sad | 0'
+    word_ids, label = imdb_dataset.get_words_and_label(line)
+    feed = {"words": word_ids}
+    fetch = ["acc", "cost", "prediction"]
+    fetch_maps = client.predict(feed=feed, fetch=fetch)
+    if len(fetch_maps) == 1:
+        print("step: {}, res: {}".format(i, fetch_maps['prediction'][0][1]))
+    else:
+        for model, fetch_map in fetch_maps.items():
+            print("step: {}, model: {}, res: {}".format(i, model, fetch_map[
+                'prediction'][0][1]))
diff --git a/python/examples/imdb/test_ensemble_server.py b/python/examples/imdb/test_ensemble_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..464288a0a167d8487f787d12c4b44a138da86f88
--- /dev/null
+++ b/python/examples/imdb/test_ensemble_server.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# pylint: disable=doc-string-missing
+
+from paddle_serving_server import OpMaker
+from paddle_serving_server import OpGraphMaker
+from paddle_serving_server import Server
+
+op_maker = OpMaker()
+read_op = op_maker.create('general_reader')
+cnn_infer_op = op_maker.create(
+    'general_infer', engine_name='cnn', inputs=[read_op])
+bow_infer_op = op_maker.create(
+    'general_infer', engine_name='bow', inputs=[read_op])
+response_op = op_maker.create(
+    'general_response', inputs=[cnn_infer_op, bow_infer_op])
+
+op_graph_maker = OpGraphMaker()
+op_graph_maker.add_op(read_op)
+op_graph_maker.add_op(cnn_infer_op)
+op_graph_maker.add_op(bow_infer_op)
+op_graph_maker.add_op(response_op)
+
+server = Server()
+server.set_op_graph(op_graph_maker.get_op_graph())
+model_config = {cnn_infer_op: 'imdb_cnn_model', bow_infer_op: 'imdb_bow_model'}
+server.load_model_config(model_config)
+server.prepare_server(workdir="work_dir1", port=9393, device="cpu")
+server.run_server()
diff --git a/python/paddle_serving_app/reader/pddet/__init__.py b/python/paddle_serving_app/reader/pddet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4e356387b5855cb60606f7bdebe7d8c6d091814
--- /dev/null
+++ b/python/paddle_serving_app/reader/pddet/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+import argparse
+from .image_tool import Resize, Detection
diff --git a/python/paddle_serving_app/reader/pddet/image_tool.py b/python/paddle_serving_app/reader/pddet/image_tool.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b461bc491a90d25a3259cf6db806beae6dbf593
--- /dev/null
+++ b/python/paddle_serving_app/reader/pddet/image_tool.py
@@ -0,0 +1,620 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import time
+
+import numpy as np
+from PIL import Image, ImageDraw
+import cv2
+import yaml
+import copy
+import argparse
+import logging
+import paddle.fluid as fluid
+import json
+
+FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+logger = logging.getLogger(__name__)
+
+precision_map = {
+    'trt_int8': fluid.core.AnalysisConfig.Precision.Int8,
+    'trt_fp32': fluid.core.AnalysisConfig.Precision.Float32,
+    'trt_fp16': fluid.core.AnalysisConfig.Precision.Half
+}
+
+
+class Resize(object):
+    def __init__(self,
+                 target_size,
+                 max_size=0,
+                 interp=cv2.INTER_LINEAR,
+                 use_cv2=True,
+                 image_shape=None):
+        super(Resize, self).__init__()
+        self.target_size = target_size
+        self.max_size = max_size
+        self.interp = interp
+        self.use_cv2 = use_cv2
+        self.image_shape = image_shape
+
+    def __call__(self, im):
+        origin_shape = im.shape[:2]
+        im_c = im.shape[2]
+        if self.max_size != 0:
+            im_size_min = np.min(origin_shape[0:2])
+            im_size_max = np.max(origin_shape[0:2])
+            im_scale = float(self.target_size) / float(im_size_min)
+            if np.round(im_scale * im_size_max) > self.max_size:
+                im_scale = float(self.max_size) / float(im_size_max)
+            im_scale_x = im_scale
+            im_scale_y = im_scale
+            resize_w = int(im_scale_x * float(origin_shape[1]))
+            resize_h = int(im_scale_y * float(origin_shape[0]))
+        else:
+            im_scale_x = float(self.target_size) / float(origin_shape[1])
+            im_scale_y = float(self.target_size) / float(origin_shape[0])
+            resize_w = self.target_size
+            resize_h = self.target_size
+        if self.use_cv2:
+            im = cv2.resize(
+                im,
+                None,
+                None,
+                fx=im_scale_x,
+                fy=im_scale_y,
+                interpolation=self.interp)
+        else:
+            if self.max_size != 0:
+                raise TypeError(
+                    'If you set max_size to cap the maximum size of image,'
+                    'please set use_cv2 to True to resize the image.')
+            im = im.astype('uint8')
+            im = Image.fromarray(im)
+            im = im.resize((int(resize_w), int(resize_h)), self.interp)
+            im = np.array(im)
+        # padding im
+        if self.max_size != 0 and self.image_shape is not None:
+            padding_im = np.zeros(
+                (self.max_size, self.max_size, im_c), dtype=np.float32)
+            im_h, im_w = im.shape[:2]
+            padding_im[:im_h, :im_w, :] = im
+            im = padding_im
+        return im, im_scale_x
+
+
+class Normalize(object):
+    def __init__(self, mean, std, is_scale=True, is_channel_first=False):
+        super(Normalize, self).__init__()
+        self.mean = mean
+        self.std = std
+        self.is_scale = is_scale
+        self.is_channel_first = is_channel_first
+
+    def __call__(self, im):
+        im = im.astype(np.float32, copy=False)
+        if self.is_channel_first:
+            mean = np.array(self.mean)[:, np.newaxis, np.newaxis]
+            std = np.array(self.std)[:, np.newaxis, np.newaxis]
+        else:
+            mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
+            std = np.array(self.std)[np.newaxis, np.newaxis, :]
+        if self.is_scale:
+            im = im / 255.0
+        im -= mean
+        im /= std
+        return im
+
+
+class Permute(object):
+    def __init__(self, to_bgr=False, channel_first=True):
+        self.to_bgr = to_bgr
+        self.channel_first = channel_first
+
+    def __call__(self, im):
+        if self.channel_first:
+            im = im.transpose((2, 0, 1))
+        if self.to_bgr:
+            im = im[[2, 1, 0], :, :]
+        return im.copy()
+
+
+class PadStride(object):
+    def __init__(self, stride=0):
+        assert stride >= 0, (
+            "Unsupported stride: {}, the stride in PadStride must be "
+            "greater than or equal to 0".format(stride))
+        self.coarsest_stride = stride
+
+    def __call__(self, im):
+        coarsest_stride = self.coarsest_stride
+        if coarsest_stride == 0:
+            return im
+        im_c, im_h, im_w = im.shape
+        pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
+        pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
+        padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
+        padding_im[:, :im_h, :im_w] = im
+        return padding_im
+
+
+class Detection():
+    def __init__(self, config_path, output_dir):
+        self.config_path = config_path
+        self.if_visualize = True
+        self.if_dump_result = True
+        self.output_dir = output_dir
+
+    def DecodeImage(self, im_path):
+        assert os.path.exists(im_path), "Image path {} can not be found".format(
+            im_path)
+        with open(im_path, 'rb') as f:
+            im = f.read()
+        data = np.frombuffer(im, dtype='uint8')
+        im = cv2.imdecode(data, 1)  # BGR mode, but need RGB mode
+        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+        return im
+
+    def Preprocess(self, img_path, arch, config):
+        img = self.DecodeImage(img_path)
+        orig_shape = img.shape
+        scale = 1.
+        data = []
+        data_config = copy.deepcopy(config)
+        for data_aug_conf in data_config:
+            obj = data_aug_conf.pop('type')
+            preprocess = eval(obj)(**data_aug_conf)
+            if obj == 'Resize':
+                img, scale = preprocess(img)
+            else:
+                img = preprocess(img)
+
+        img = img[np.newaxis, :]  # N, C, H, W
+        data.append(img)
+        extra_info = self.get_extra_info(img, arch, orig_shape, scale)
+        data += extra_info
+        return data
+
+    def expand_boxes(self, boxes, scale):
+        """
+        Expand an array of boxes by a given scale.
+        """
+        w_half = (boxes[:, 2] - boxes[:, 0]) * .5
+        h_half = (boxes[:, 3] - boxes[:, 1]) * .5
+        x_c = (boxes[:, 2] + boxes[:, 0]) * .5
+        y_c = (boxes[:, 3] + boxes[:, 1]) * .5
+
+        w_half *= scale
+        h_half *= scale
+
+        boxes_exp = np.zeros(boxes.shape)
+        boxes_exp[:, 0] = x_c - w_half
+        boxes_exp[:, 2] = x_c + w_half
+        boxes_exp[:, 1] = y_c - h_half
+        boxes_exp[:, 3] = y_c + h_half
+
+        return boxes_exp
+
+    def draw_bbox(self, image, catid2name, bboxes, threshold, color_list):
+        """
+        draw bbox on image
+        """
+        draw = ImageDraw.Draw(image)
+
+        for dt in np.array(bboxes):
+            catid, bbox, score = dt['category_id'], dt['bbox'], dt['score']
+            if score < threshold:
+                continue
+
+            xmin, ymin, w, h = bbox
+            xmax = xmin + w
+            ymax = ymin + h
+
+            color = tuple(color_list[catid])
+
+            # draw bbox
+            draw.line(
+                [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
+                 (xmin, ymin)],
+                width=2,
+                fill=color)
+
+            # draw label
+            text = "{} {:.2f}".format(catid2name[catid], score)
+            tw, th = draw.textsize(text)
+            draw.rectangle(
+                [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
+            draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
+
+        return image
+
+    def draw_mask(self, image, masks, threshold, color_list, alpha=0.7):
+        """
+        Draw mask on image
+        """
+        mask_color_id = 0
+        w_ratio = .4
+        img_array = np.array(image).astype('float32')
+        for dt in np.array(masks):
+            segm, score = dt['segmentation'], dt['score']
+            if score < threshold:
+                continue
+            import pycocotools.mask as mask_util
+            mask = mask_util.decode(segm) * 255
+            color_mask = color_list[mask_color_id % len(color_list), 0:3]
+            mask_color_id += 1
+            for c in range(3):
+                color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
+            idx = np.nonzero(mask)
+            img_array[idx[0], idx[1], :] *= 1.0 - alpha
+            img_array[idx[0], idx[1], :] += alpha * color_mask
+        return Image.fromarray(img_array.astype('uint8'))
+
+    def get_extra_info(self, im, arch, shape, scale):
+        info = []
+        input_shape = []
+        im_shape = []
+        logger.info('The architecture is {}'.format(arch))
+        if 'YOLO' in arch:
+            im_size = np.array([shape[:2]]).astype('int32')
+            logger.info('Extra info: im_size')
+            info.append(im_size)
+        elif 'SSD' in arch:
+            im_shape = np.array([shape[:2]]).astype('int32')
+            logger.info('Extra info: im_shape')
+            info.append([im_shape])
+        elif 'RetinaNet' in arch:
+            input_shape.extend(im.shape[2:])
+            im_info = np.array([input_shape + [scale]]).astype('float32')
+            logger.info('Extra info: im_info')
+            info.append(im_info)
+        elif 'RCNN' in arch:
+            input_shape.extend(im.shape[2:])
+            im_shape.extend(shape[:2])
+            im_info = np.array([input_shape + [scale]]).astype('float32')
+            im_shape = np.array([im_shape + [1.]]).astype('float32')
+            logger.info('Extra info: im_info, im_shape')
+            info.append(im_info)
+            info.append(im_shape)
+        else:
+            logger.error(
+                "Unsupported arch: {}, expect YOLO, SSD, RetinaNet and RCNN".
+                format(arch))
+        return info
+
+    def offset_to_lengths(self, lod):
+        offset = lod[0]
+        lengths = [offset[i + 1] - offset[i] for i in range(len(offset) - 1)]
+        return [lengths]
+
+    def bbox2out(self, results, clsid2catid, is_bbox_normalized=False):
+        """
+        Args:
+            results: request a dict, should include: `bbox`, `im_id`,
+                     if is_bbox_normalized=True, also need `im_shape`.
+            clsid2catid: class id to category id map of COCO2017 dataset.
+            is_bbox_normalized: whether or not bbox is normalized.
+        """
+        xywh_res = []
+        for t in results:
+            bboxes = t['bbox'][0]
+            lengths = t['bbox'][1][0]
+            if bboxes.shape == (1, 1) or bboxes is None:
+                continue
+
+            k = 0
+            for i in range(len(lengths)):
+                num = lengths[i]
+                for j in range(num):
+                    dt = bboxes[k]
+                    clsid, score, xmin, ymin, xmax, ymax = dt.tolist()
+                    catid = (clsid2catid[int(clsid)])
+
+                    if is_bbox_normalized:
+                        xmin, ymin, xmax, ymax = \
+                            self.clip_bbox([xmin, ymin, xmax, ymax])
+                        w = xmax - xmin
+                        h = ymax - ymin
+                        im_shape = t['im_shape'][0][i].tolist()
+                        im_height, im_width = int(im_shape[0]), int(im_shape[1])
+                        xmin *= im_width
+                        ymin *= im_height
+                        w *= im_width
+                        h *= im_height
+                    else:
+                        w = xmax - xmin + 1
+                        h = ymax - ymin + 1
+
+                    bbox = [xmin, ymin, w, h]
+                    coco_res = {
+                        'category_id': catid,
+                        'bbox': bbox,
+                        'score': score
+                    }
+                    xywh_res.append(coco_res)
+                    k += 1
+        return xywh_res
+
+    def get_bbox_result(self, fetch_map, fetch_name, result, conf, clsid2catid):
+        is_bbox_normalized = True if 'SSD' in conf['arch'] else False
+        output = fetch_map[fetch_name]
+        lod = [fetch_map[fetch_name + '.lod']]
+        lengths = self.offset_to_lengths(lod)
+        np_data = np.array(output)
+        result['bbox'] = (np_data, lengths)
+        result['im_id'] = np.array([[0]])
+
+        bbox_results = self.bbox2out([result], clsid2catid, is_bbox_normalized)
+        return bbox_results
+
+    def mask2out(self, results, clsid2catid, resolution, thresh_binarize=0.5):
+        import pycocotools.mask as mask_util
+        scale = (resolution + 2.0) / resolution
+
+        segm_res = []
+
+        for t in results:
+            bboxes = t['bbox'][0]
+            lengths = t['bbox'][1][0]
+            if bboxes.shape == (1, 1) or bboxes is None:
+                continue
+            if len(bboxes.tolist()) == 0:
+                continue
+            masks = t['mask'][0]
+
+            s = 0
+            # for each sample
+            for i in range(len(lengths)):
+                num = lengths[i]
+                im_shape = t['im_shape'][i]
+
+                bbox = bboxes[s:s + num][:, 2:]
+                clsid_scores = bboxes[s:s + num][:, 0:2]
+                mask = masks[s:s + num]
+                s += num
+
+                im_h = int(im_shape[0])
+                im_w = int(im_shape[1])
+
+                expand_bbox = self.expand_boxes(bbox, scale)
+                expand_bbox = expand_bbox.astype(np.int32)
+
+                padded_mask = np.zeros(
+                    (resolution + 2, resolution + 2), dtype=np.float32)
+
+                for j in range(num):
+                    xmin, ymin, xmax, ymax = expand_bbox[j].tolist()
+                    clsid, score = clsid_scores[j].tolist()
+                    clsid = int(clsid)
+                    padded_mask[1:-1, 1:-1] = mask[j, clsid, :, :]
+
+                    catid = clsid2catid[clsid]
+
+                    w = xmax - xmin + 1
+                    h = ymax - ymin + 1
+                    w = np.maximum(w, 1)
+                    h = np.maximum(h, 1)
+
+                    resized_mask = cv2.resize(padded_mask, (w, h))
+                    resized_mask = np.array(
+                        resized_mask > thresh_binarize, dtype=np.uint8)
+                    im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
+
+                    x0 = min(max(xmin, 0), im_w)
+                    x1 = min(max(xmax + 1, 0), im_w)
+                    y0 = min(max(ymin, 0), im_h)
+                    y1 = min(max(ymax + 1, 0), im_h)
+
+                    im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(
+                        y1 - ymin), (x0 - xmin):(x1 - xmin)]
+                    segm = mask_util.encode(
+                        np.array(
+                            im_mask[:, :, np.newaxis], order='F'))[0]
+                    catid = clsid2catid[clsid]
+                    segm['counts'] = segm['counts'].decode('utf8')
+                    coco_res = {
+                        'category_id': catid,
+                        'segmentation': segm,
+                        'score': score
+                    }
+                    segm_res.append(coco_res)
+        return segm_res
+
+    def get_mask_result(self, fetch_map, fetch_var_names, result, conf,
+                        clsid2catid):
+        resolution = conf['mask_resolution']
+        bbox_out, mask_out = fetch_map[fetch_var_names]
+        lengths = self.offset_to_lengths(bbox_out.lod())
+        bbox = np.array(bbox_out)
+        mask = np.array(mask_out)
+        result['bbox'] = (bbox, lengths)
+        result['mask'] = (mask, lengths)
+        mask_results = self.mask2out([result], clsid2catid,
+                                     conf['mask_resolution'])
+        return mask_results
+
+    def get_category_info(self, with_background, label_list):
+        if label_list[0] != 'background' and with_background:
+            label_list.insert(0, 'background')
+        if label_list[0] == 'background' and not with_background:
+            label_list = label_list[1:]
+        clsid2catid = {i: i for i in range(len(label_list))}
+        catid2name = {i: name for i, name in enumerate(label_list)}
+        return clsid2catid, catid2name
+
+    def color_map(self, num_classes):
+        color_map = num_classes * [0, 0, 0]
+        for i in range(0, num_classes):
+            j = 0
+            lab = i
+            while lab:
+                color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
+                color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
+                color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
+                j += 1
+                lab >>= 3
+        color_map = np.array(color_map).reshape(-1, 3)
+        return color_map
+
+    def visualize(self,
+                  bbox_results,
+                  catid2name,
+                  num_classes,
+                  mask_results=None):
+        image = Image.open(self.infer_img).convert('RGB')
+        color_list = self.color_map(num_classes)
+        image = self.draw_bbox(image, catid2name, bbox_results, 0.5, color_list)
+        if mask_results is not None:
+            image = self.draw_mask(image, mask_results, 0.5, color_list)
+        image_path = os.path.split(self.infer_img)[-1]
+        if not os.path.exists(self.output_dir):
+            os.makedirs(self.output_dir)
+        out_path = os.path.join(self.output_dir, image_path)
+        image.save(out_path, quality=95)
+        logger.info('Save visualize result to {}'.format(out_path))
+
+    def preprocess(self, feed_var_names, image_file):
+        self.infer_img = image_file
+        config_path = self.config_path
+        res = {}
+        assert config_path is not None, "Config path: {} does not exist!".format(
+            config_path)
+        with open(config_path) as f:
+            conf = yaml.safe_load(f)
+
+        img_data = self.Preprocess(image_file, conf['arch'], conf['Preprocess'])
+        if 'SSD' in conf['arch']:
+            img_data, res['im_shape'] = img_data
+            img_data = [img_data]
+        if len(feed_var_names) != len(img_data):
+            raise ValueError(
+                'the length of feed vars does not equal the length of the preprocessed img data, please check your feed dict'
+            )
+
+        def processImg(v):
+            np_data = np.array(v[0])
+            res = np_data
+            return res
+
+        feed_dict = {k: processImg(v) for k, v in zip(feed_var_names, img_data)}
+        return feed_dict
+
+    def postprocess(self, fetch_map, fetch_var_names):
+        config_path = self.config_path
+        res = {}
+        with open(config_path) as f:
+            conf = yaml.safe_load(f)
+        if 'SSD' in conf['arch']:
+            img_data, res['im_shape'] = img_data
+            img_data = [img_data]
+        clsid2catid, catid2name = self.get_category_info(
+            conf['with_background'], conf['label_list'])
+        bbox_result = self.get_bbox_result(fetch_map, fetch_var_names[0], res,
+                                           conf, clsid2catid)
+        mask_result = None
+        if 'mask_resolution' in conf:
+            res['im_shape'] = img_data[-1]
+            mask_result = self.get_mask_result(fetch_map, fetch_var_names, res,
+                                               conf, clsid2catid)
+        if self.if_visualize:
+            if os.path.isdir(self.output_dir) is False:
+                os.mkdir(self.output_dir)
+            self.visualize(bbox_result, catid2name,
+                           len(conf['label_list']), mask_result)
+        if self.if_dump_result:
+            if os.path.isdir(self.output_dir) is False:
+                os.mkdir(self.output_dir)
+            bbox_file = os.path.join(self.output_dir, 'bbox.json')
+            logger.info('dump bbox to {}'.format(bbox_file))
+            with open(bbox_file, 'w') as f:
+                json.dump(bbox_result, f, indent=4)
+            if mask_result is not None:
+                mask_file = os.path.join(self.output_dir, 'mask.json')
+                logger.info('dump mask to {}'.format(mask_file))
+                with open(mask_file, 'w') as f:
+                    json.dump(mask_result, f, indent=4)
diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py
index 98d233f059a1ad0b588bce5bf3ef831d783c3a44..053062ee508b33e7602dea5a53b4868a662452cd 100644
--- a/python/paddle_serving_client/__init__.py
+++ b/python/paddle_serving_client/__init__.py
@@ -264,28 +264,46 @@ class Client(object):
         if res == -1:
             return None
 
-        result_map_batch = []
-        result_map = {}
-        # result map needs to be a numpy array
-        for i, name in enumerate(fetch_names):
-            if self.fetch_names_to_type_[name] == int_type:
-                result_map[name] = result_batch.get_int64_by_name(name)
-                shape = result_batch.get_shape(name)
-                result_map[name] = np.array(result_map[name])
-                result_map[name].shape = shape
-                if name in self.lod_tensor_set:
-                    result_map["{}.lod".format(name)] = result_batch.get_lod(
-                        name)
-            elif self.fetch_names_to_type_[name] == float_type:
-                result_map[name] = result_batch.get_float_by_name(name)
-                shape = result_batch.get_shape(name)
-                result_map[name] = np.array(result_map[name])
-                result_map[name].shape = shape
-                if name in self.lod_tensor_set:
-                    result_map["{}.lod".format(name)] = result_batch.get_lod(
-                        name)
-
-        return result_map
+        multi_result_map = []
+        model_engine_names = result_batch.get_engine_names()
+        for mi, engine_name in enumerate(model_engine_names):
+            result_map = {}
+            # result map needs to be a numpy array
+            for i, name in enumerate(fetch_names):
+                if self.fetch_names_to_type_[name] == int_type:
+                    result_map[name] = result_batch.get_int64_by_name(mi, name)
+                    shape = result_batch.get_shape(mi, name)
+                    result_map[name] = np.array(result_map[name], dtype='int64')
+                    result_map[name].shape = shape
+                    if name in self.lod_tensor_set:
+                        result_map["{}.lod".format(
+                            name)] = result_batch.get_lod(mi, name)
+                elif self.fetch_names_to_type_[name] == float_type:
+                    result_map[name] = result_batch.get_float_by_name(mi, name)
+                    shape = result_batch.get_shape(mi, name)
+                    result_map[name] = np.array(
+                        result_map[name], dtype='float32')
+                    result_map[name].shape = shape
+                    if name in self.lod_tensor_set:
+                        result_map["{}.lod".format(
+                            name)] = result_batch.get_lod(mi, name)
+            multi_result_map.append(result_map)
+
+        ret = None
+        if len(model_engine_names) == 1:
+            # If only one model result is returned, the format of ret is result_map
+            ret = multi_result_map[0]
+        else:
+            # If multiple model results are returned, the format of ret is {name: result_map}
+            ret = {
+                engine_name: multi_result_map[mi]
+                for mi, engine_name in enumerate(model_engine_names)
+            }
+
+        # When using the A/B test, the tag of variant needs to be returned
+        return ret if not need_variant_tag else [
+            ret, self.result_handle_.variant_tag()
+        ]
 
     def release(self):
         self.client_handle_.destroy_predictor()
diff --git a/python/paddle_serving_client/io/__init__.py b/python/paddle_serving_client/io/__init__.py
index d723795f214e22957bff49f0ddf8fd42086b8a7e..74a6ca871b5c1e32b3c1ecbc6656c95d7c78a399 100644
--- a/python/paddle_serving_client/io/__init__.py
+++ b/python/paddle_serving_client/io/__init__.py
@@ -20,6 +20,7 @@ from paddle.fluid.framework import default_main_program
 from paddle.fluid.framework import Program
 from paddle.fluid import CPUPlace
 from paddle.fluid.io import save_inference_model
+import paddle.fluid as fluid
 from ..proto import general_model_config_pb2 as model_conf
 import os
 
@@ -100,3 +101,20 @@ def save_model(server_model_folder,
     with open("{}/serving_server_conf.stream.prototxt".format(
             server_model_folder), "wb") as fout:
         fout.write(config.SerializeToString())
+
+
+def inference_model_to_serving(infer_model, serving_client, serving_server):
+    place = fluid.CPUPlace()
+    exe = fluid.Executor(place)
+    inference_program, feed_target_names, fetch_targets = \
+            fluid.io.load_inference_model(dirname=infer_model, executor=exe)
+    feed_dict = {
+        x: inference_program.global_block().var(x)
+        for x in feed_target_names
+    }
+    fetch_dict = {x.name: x for x in fetch_targets}
+    save_model(serving_client, serving_server, feed_dict, fetch_dict,
+               inference_program)
+    feed_names = feed_dict.keys()
+    fetch_names = fetch_dict.keys()
+    return feed_names, fetch_names
diff --git a/python/paddle_serving_server/__init__.py b/python/paddle_serving_server/__init__.py
index 68a3b424cd9bda861ddf0b54ac42bf356c3de310..a58fb11ac3ee1fbe5086ae4381f6d6208c0c73ec 100644
--- a/python/paddle_serving_server/__init__.py
+++ b/python/paddle_serving_server/__init__.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# pylint: disable=doc-string-missing
 
 import os
 from .proto import server_configure_pb2 as server_sdk
@@ -21,6 +22,7 @@ import socket
 import paddle_serving_server as paddle_serving_server
 from .version import serving_server_version
 from contextlib import closing
+import collections
 
 
 class OpMaker(object):
@@ -36,17 +38,35 @@ class OpMaker(object):
             "general_dist_kv_quant_infer": "GeneralDistKVQuantInferOp",
             "general_copy": "GeneralCopyOp"
         }
+        self.node_name_suffix_ = collections.defaultdict(int)
 
-    # currently, inputs and outputs are not used
-    # when we have OpGraphMaker, inputs and outputs are necessary
-    def create(self, name, inputs=[], outputs=[]):
-        if name not in self.op_dict:
-            raise Exception("Op name {} is not supported right now".format(
-                name))
+    def create(self, node_type, engine_name=None, inputs=[], outputs=[]):
+        if node_type not in self.op_dict:
+            raise Exception("Op type {} is not supported right now".format(
+                node_type))
         node = server_sdk.DAGNode()
-        node.name = "{}_op".format(name)
-        node.type = self.op_dict[name]
-        return node
+        # node.name will be used as the infer engine name
+        if engine_name:
+            node.name = engine_name
+        else:
+            node.name = '{}_{}'.format(node_type,
+                                       self.node_name_suffix_[node_type])
+            self.node_name_suffix_[node_type] += 1
+
+        node.type = self.op_dict[node_type]
+        if inputs:
+            for dep_node_str in inputs:
+                dep_node = server_sdk.DAGNode()
+                google.protobuf.text_format.Parse(dep_node_str, dep_node)
+                dep = server_sdk.DAGNodeDependency()
+                dep.name = dep_node.name
+                dep.mode = "RO"
+                node.dependencies.extend([dep])
+        # Because the return value will be used as a key of a dict,
+        # and the proto object is mutable and cannot be hashed,
+        # it is serialized into a string here. This has little effect
+        # on overall efficiency.
+        return google.protobuf.text_format.MessageToString(node)
 
 
 class OpSeqMaker(object):
@@ -55,12 +75,25 @@ class OpSeqMaker(object):
         self.workflow.name = "workflow1"
         self.workflow.workflow_type = "Sequence"
 
-    def add_op(self, node):
+    def add_op(self, node_str):
+        node = server_sdk.DAGNode()
+        google.protobuf.text_format.Parse(node_str, node)
+        if len(node.dependencies) > 1:
+            raise Exception(
+                'Setting more than one predecessor for an op in OpSeqMaker is not allowed.'
+            )
         if len(self.workflow.nodes) >= 1:
-            dep = server_sdk.DAGNodeDependency()
-            dep.name = self.workflow.nodes[-1].name
-            dep.mode = "RO"
-            node.dependencies.extend([dep])
+            if len(node.dependencies) == 0:
+                dep = server_sdk.DAGNodeDependency()
+                dep.name = self.workflow.nodes[-1].name
+                dep.mode = "RO"
+                node.dependencies.extend([dep])
+            elif len(node.dependencies) == 1:
+                if node.dependencies[0].name != self.workflow.nodes[-1].name:
+                    raise Exception(
+                        'You must add ops in order in OpSeqMaker. The predecessor declared for the current op is {}, but the last op added is {}.'.
+                        format(node.dependencies[0].name, self.workflow.nodes[
+                            -1].name))
         self.workflow.nodes.extend([node])
 
     def get_op_sequence(self):
@@ -69,13 +102,30 @@ class OpSeqMaker(object):
         return workflow_conf
 
 
+class OpGraphMaker(object):
+    def __init__(self):
+        self.workflow = server_sdk.Workflow()
+        self.workflow.name = "workflow1"
+        # Currently, SDK only supports "Sequence"
+        self.workflow.workflow_type = "Sequence"
+
+    def add_op(self, node_str):
+        node = server_sdk.DAGNode()
+        google.protobuf.text_format.Parse(node_str, node)
+        self.workflow.nodes.extend([node])
+
+    def get_op_graph(self):
+        workflow_conf = server_sdk.WorkflowConf()
+        workflow_conf.workflows.extend([self.workflow])
+        return workflow_conf
+
+
 class Server(object):
     def __init__(self):
         self.server_handle_ = None
         self.infer_service_conf = None
         self.model_toolkit_conf = None
         self.resource_conf = None
-        self.engine = None
         self.memory_optimization = False
         self.model_conf = None
         self.workflow_fn = "workflow.prototxt"
@@ -94,6 +144,7 @@ class Server(object):
         self.cur_path = os.getcwd()
         self.use_local_bin = False
         self.mkl_flag = False
+        self.model_config_paths = None  # for multi-model in a workflow
 
     def set_max_concurrency(self, concurrency):
         self.max_concurrency = concurrency
@@ -118,6 +169,9 @@ class Server(object):
     def set_op_sequence(self, op_seq):
         self.workflow_conf = op_seq
 
+    def set_op_graph(self, op_graph):
+        self.workflow_conf = op_graph
+
     def set_memory_optimize(self, flag=False):
         self.memory_optimization = flag
 
@@ -126,32 +180,30 @@ class Server(object):
             self.use_local_bin = True
             self.bin_path = os.environ["SERVING_BIN"]
 
-    def _prepare_engine(self, model_config_path, device):
+    def _prepare_engine(self, model_config_paths, device):
         if self.model_toolkit_conf == None:
             self.model_toolkit_conf = server_sdk.ModelToolkitConf()
 
-        if self.engine == None:
-            self.engine = server_sdk.EngineDesc()
-
-        self.model_config_path = model_config_path
-        self.engine.name = "general_model"
-        self.engine.reloadable_meta = model_config_path + "/fluid_time_file"
-        os.system("touch {}".format(self.engine.reloadable_meta))
-        self.engine.reloadable_type = "timestamp_ne"
-        self.engine.runtime_thread_num = 0
-        self.engine.batch_infer_size = 0
-        self.engine.enable_batch_align = 0
-        self.engine.model_data_path = model_config_path
-        self.engine.enable_memory_optimization = self.memory_optimization
-        self.engine.static_optimization = False
-        self.engine.force_update_static_cache = False
-
-        if device == "cpu":
-            self.engine.type = "FLUID_CPU_ANALYSIS_DIR"
-        elif device == "gpu":
-            self.engine.type = "FLUID_GPU_ANALYSIS_DIR"
-
-        self.model_toolkit_conf.engines.extend([self.engine])
+        for engine_name, model_config_path in model_config_paths.items():
+            engine = server_sdk.EngineDesc()
+            engine.name = engine_name
+            engine.reloadable_meta = model_config_path + "/fluid_time_file"
+            os.system("touch {}".format(engine.reloadable_meta))
+            engine.reloadable_type = "timestamp_ne"
+            engine.runtime_thread_num = 0
+            engine.batch_infer_size = 0
+            engine.enable_batch_align = 0
+            engine.model_data_path = model_config_path
+            engine.enable_memory_optimization = self.memory_optimization
+            engine.static_optimization = False
+            engine.force_update_static_cache = False
+
+            if device == "cpu":
+                engine.type = "FLUID_CPU_ANALYSIS_DIR"
+            elif device == "gpu":
+                engine.type = "FLUID_GPU_ANALYSIS_DIR"
+
+            self.model_toolkit_conf.engines.extend([engine])
 
     def _prepare_infer_service(self, port):
         if self.infer_service_conf == None:
@@ -184,10 +236,49 @@ class Server(object):
         with open(filepath, "w") as fout:
             fout.write(str(pb_obj))
 
-    def load_model_config(self, path):
-        self.model_config_path = path
+    def load_model_config(self, model_config_paths):
+        # At present, Serving needs to configure the model path in
+        # the resource.prototxt file to determine the input and output
+        # format of the workflow, so the input and output of multiple
+        # models must be the same.
+        workflow_oi_config_path = None
+        if isinstance(model_config_paths, str):
+            # If there is only one model path, use the default infer_op.
+            # Because there are several infer_op types, we need to find
+            # it from workflow_conf.
+            default_engine_names = [
+                'general_infer_0', 'general_dist_kv_infer_0',
+                'general_dist_kv_quant_infer_0'
+            ]
+            engine_name = None
+            for node in self.workflow_conf.workflows[0].nodes:
+                if node.name in default_engine_names:
+                    engine_name = node.name
+                    break
+            if engine_name is None:
+                raise Exception(
+                    "You have set the engine_name of Op. Please use the form {op: model_path} to configure model path"
+                )
+            self.model_config_paths = {engine_name: model_config_paths}
+            workflow_oi_config_path = self.model_config_paths[engine_name]
+        elif isinstance(model_config_paths, dict):
+            self.model_config_paths = {}
+            for node_str, path in model_config_paths.items():
+                node = server_sdk.DAGNode()
+                google.protobuf.text_format.Parse(node_str, node)
+                self.model_config_paths[node.name] = path
+            print("You have specified multiple model paths, please ensure "
+                  "that the input and output of multiple models are the same.")
+            workflow_oi_config_path = list(self.model_config_paths.items())[0][1]
+        else:
+            raise Exception("The type of model_config_paths must be str or "
+                            "dict({op: model_path}), not {}.".format(
+                                type(model_config_paths)))
+
         self.model_conf = m_config.GeneralModelConfig()
-        f = open("{}/serving_server_conf.prototxt".format(path), 'r')
+        f = open(
+            "{}/serving_server_conf.prototxt".format(workflow_oi_config_path),
+            'r')
         self.model_conf = google.protobuf.text_format.Merge(
             str(f.read()), self.model_conf)
         # check config here
@@ -258,7 +349,7 @@ class Server(object):
         if not self.port_is_available(port):
             raise SystemExit("Prot {} is already used".format(port))
         self._prepare_resource(workdir)
-        self._prepare_engine(self.model_config_path, device)
+        self._prepare_engine(self.model_config_paths, device)
         self._prepare_infer_service(port)
         self.port = port
         self.workdir = workdir
diff --git a/python/paddle_serving_server_gpu/__init__.py b/python/paddle_serving_server_gpu/__init__.py
index 0e5a49c4870956557c99fdf8abf08edf47bc4aa6..45e71a383b4fe0e5ca3a5284985b702cd815f18c 100644
--- a/python/paddle_serving_server_gpu/__init__.py
+++ b/python/paddle_serving_server_gpu/__init__.py
@@ -24,6 +24,7 @@ import time
 from .version import serving_server_version
 from contextlib import closing
 import argparse
+import collections
 
 
 def serve_args():
@@ -66,17 +67,35 @@ class OpMaker(object):
             "general_dist_kv_infer": "GeneralDistKVInferOp",
             "general_dist_kv": "GeneralDistKVOp"
         }
+        self.node_name_suffix_ = collections.defaultdict(int)
 
-    # currently, inputs and outputs are not used
-    # when we have OpGraphMaker, inputs and outputs are necessary
-    def create(self, name, inputs=[], outputs=[]):
-        if name not in self.op_dict:
-            raise Exception("Op name {} is not supported right now".format(
-                name))
+    def create(self, node_type, engine_name=None, inputs=[], outputs=[]):
+        if node_type not in self.op_dict:
+            raise Exception("Op type {} is not supported right now".format(
+                node_type))
         node = server_sdk.DAGNode()
-        node.name = "{}_op".format(name)
-        node.type = self.op_dict[name]
-        return node
+        # node.name will be used as the infer engine name
+        if engine_name:
+            node.name = engine_name
+        else:
+            node.name = '{}_{}'.format(node_type,
+                                       self.node_name_suffix_[node_type])
+            self.node_name_suffix_[node_type] += 1
+
+        node.type = self.op_dict[node_type]
+        if inputs:
+            for dep_node_str in inputs:
+                dep_node = server_sdk.DAGNode()
+                google.protobuf.text_format.Parse(dep_node_str, dep_node)
+                dep = server_sdk.DAGNodeDependency()
+                dep.name = dep_node.name
+                dep.mode = "RO"
+                node.dependencies.extend([dep])
+        # The return value may later be used as a dict key (for example in
+        # Server.load_model_config), but a proto message is mutable and thus
+        # not hashable, so the node is serialized to a string here. This has
+        # little effect on overall efficiency.
+        return google.protobuf.text_format.MessageToString(node)
 
 
 class OpSeqMaker(object):
@@ -85,12 +104,25 @@ class OpSeqMaker(object):
         self.workflow.name = "workflow1"
         self.workflow.workflow_type = "Sequence"
 
-    def add_op(self, node):
+    def add_op(self, node_str):
+        node = server_sdk.DAGNode()
+        google.protobuf.text_format.Parse(node_str, node)
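+        # A node may declare at most one predecessor here; if it declares none,
+        # the op most recently added to the sequence becomes its predecessor.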
+        if len(node.dependencies) > 1:
+            raise Exception(
+                'Setting more than one predecessor for an op in OpSeqMaker is not allowed.'
+            )
         if len(self.workflow.nodes) >= 1:
-            dep = server_sdk.DAGNodeDependency()
-            dep.name = self.workflow.nodes[-1].name
-            dep.mode = "RO"
-            node.dependencies.extend([dep])
+            if len(node.dependencies) == 0:
+                dep = server_sdk.DAGNodeDependency()
+                dep.name = self.workflow.nodes[-1].name
+                dep.mode = "RO"
+                node.dependencies.extend([dep])
+            elif len(node.dependencies) == 1:
+                if node.dependencies[0].name != self.workflow.nodes[-1].name:
+                    raise Exception(
+                        'You must add ops in order in OpSeqMaker. The previous op is {}, but the current op declares {} as its predecessor.'.
+                        format(self.workflow.nodes[-1].name,
+                               node.dependencies[0].name))
         self.workflow.nodes.extend([node])
 
     def get_op_sequence(self):
@@ -99,13 +131,30 @@ class OpSeqMaker(object):
         return workflow_conf
 
 
+class OpGraphMaker(object):
+    def __init__(self):
+        self.workflow = server_sdk.Workflow()
+        self.workflow.name = "workflow1"
+        # Currently, SDK only supports "Sequence"
+        self.workflow.workflow_type = "Sequence"
+
+    def add_op(self, node_str):
+        node = server_sdk.DAGNode()
+        google.protobuf.text_format.Parse(node_str, node)
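+        # Unlike OpSeqMaker, no implicit dependency is added here: the node's
+        # predecessors must already be set via the inputs argument of
+        # OpMaker.create().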
+        self.workflow.nodes.extend([node])
+
+    def get_op_graph(self):
+        workflow_conf = server_sdk.WorkflowConf()
+        workflow_conf.workflows.extend([self.workflow])
+        return workflow_conf
+
+
 class Server(object):
     def __init__(self):
         self.server_handle_ = None
         self.infer_service_conf = None
         self.model_toolkit_conf = None
         self.resource_conf = None
-        self.engine = None
         self.memory_optimization = False
         self.model_conf = None
         self.workflow_fn = "workflow.prototxt"
@@ -125,6 +174,7 @@ class Server(object):
         self.check_cuda()
         self.use_local_bin = False
         self.gpuid = 0
+        self.model_config_paths = None  # for multi-model in a workflow
 
     def set_max_concurrency(self, concurrency):
         self.max_concurrency = concurrency
@@ -149,6 +199,9 @@ class Server(object):
     def set_op_sequence(self, op_seq):
         self.workflow_conf = op_seq
 
+    def set_op_graph(self, op_graph):
+        self.workflow_conf = op_graph
+
     def set_memory_optimize(self, flag=False):
         self.memory_optimization = flag
 
@@ -167,33 +220,31 @@ class Server(object):
     def set_gpuid(self, gpuid=0):
         self.gpuid = gpuid
 
-    def _prepare_engine(self, model_config_path, device):
+    def _prepare_engine(self, model_config_paths, device):
         if self.model_toolkit_conf == None:
             self.model_toolkit_conf = server_sdk.ModelToolkitConf()
 
-        if self.engine == None:
-            self.engine = server_sdk.EngineDesc()
-
-        self.model_config_path = model_config_path
-        self.engine.name = "general_model"
-        #self.engine.reloadable_meta = model_config_path + "/fluid_time_file"
-        self.engine.reloadable_meta = self.workdir + "/fluid_time_file"
-        os.system("touch {}".format(self.engine.reloadable_meta))
-        self.engine.reloadable_type = "timestamp_ne"
-        self.engine.runtime_thread_num = 0
-        self.engine.batch_infer_size = 0
-        self.engine.enable_batch_align = 0
-        self.engine.model_data_path = model_config_path
-        self.engine.enable_memory_optimization = self.memory_optimization
-        self.engine.static_optimization = False
-        self.engine.force_update_static_cache = False
-
-        if device == "cpu":
-            self.engine.type = "FLUID_CPU_ANALYSIS_DIR"
-        elif device == "gpu":
-            self.engine.type = "FLUID_GPU_ANALYSIS_DIR"
-
-        self.model_toolkit_conf.engines.extend([self.engine])
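+        # Build one EngineDesc per (engine_name, model_config_path) entry so
+        # that several models can be served within a single workflow.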
+        for engine_name, model_config_path in model_config_paths.items():
+            engine = server_sdk.EngineDesc()
+            engine.name = engine_name
+            # engine.reloadable_meta = model_config_path + "/fluid_time_file"
+            engine.reloadable_meta = self.workdir + "/fluid_time_file"
+            os.system("touch {}".format(engine.reloadable_meta))
+            engine.reloadable_type = "timestamp_ne"
+            engine.runtime_thread_num = 0
+            engine.batch_infer_size = 0
+            engine.enable_batch_align = 0
+            engine.model_data_path = model_config_path
+            engine.enable_memory_optimization = self.memory_optimization
+            engine.static_optimization = False
+            engine.force_update_static_cache = False
+
+            if device == "cpu":
+                engine.type = "FLUID_CPU_ANALYSIS_DIR"
+            elif device == "gpu":
+                engine.type = "FLUID_GPU_ANALYSIS_DIR"
+
+            self.model_toolkit_conf.engines.extend([engine])
 
     def _prepare_infer_service(self, port):
         if self.infer_service_conf == None:
@@ -225,10 +276,49 @@ class Server(object):
         with open(filepath, "w") as fout:
             fout.write(str(pb_obj))
 
-    def load_model_config(self, path):
-        self.model_config_path = path
+    def load_model_config(self, model_config_paths):
+        # At present, Serving needs to configure the model path in
+        # the resource.prototxt file to determine the input and output
+        # format of the workflow, so the input and output of all
+        # models in the workflow must be the same.
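+        #
+        # A minimal usage sketch (model paths are illustrative):
+        #   single model:    server.load_model_config('uci_housing_model')
+        #   multiple models: server.load_model_config({cnn_op: 'cnn_model',
+        #                                              bow_op: 'bow_model'})
+        # where cnn_op and bow_op are the node strings returned by OpMaker.create().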
+        workflow_oi_config_path = None
+        if isinstance(model_config_paths, str):
+            # If there is only one model path, use the default infer_op.
+            # Because there are several infer_op types, we need to find
+            # the one actually used from workflow_conf.
+            default_engine_names = [
+                'general_infer_0', 'general_dist_kv_infer_0',
+                'general_dist_kv_quant_infer_0'
+            ]
+            engine_name = None
+            for node in self.workflow_conf.workflows[0].nodes:
+                if node.name in default_engine_names:
+                    engine_name = node.name
+                    break
+            if engine_name is None:
+                raise Exception(
+                    "You have set the engine_name of Op. Please use the form {op: model_path} to configure model path"
+                )
+            self.model_config_paths = {engine_name: model_config_paths}
+            workflow_oi_config_path = self.model_config_paths[engine_name]
+        elif isinstance(model_config_paths, dict):
+            self.model_config_paths = {}
+            for node_str, path in model_config_paths.items():
+                node = server_sdk.DAGNode()
+                google.protobuf.text_format.Parse(node_str, node)
+                self.model_config_paths[node.name] = path
+            print("You have specified multiple model paths, please ensure "
+                  "that the input and output of multiple models are the same.")
+            workflow_oi_config_path = self.model_config_paths.items()[0][1]
+        else:
+            raise Exception("The type of model_config_paths must be str or "
+                            "dict({op: model_path}), not {}.".format(
+                                type(model_config_paths)))
+
         self.model_conf = m_config.GeneralModelConfig()
-        f = open("{}/serving_server_conf.prototxt".format(path), 'r')
+        f = open(
+            "{}/serving_server_conf.prototxt".format(workflow_oi_config_path),
+            'r')
         self.model_conf = google.protobuf.text_format.Merge(
             str(f.read()), self.model_conf)
         # check config here
@@ -291,7 +381,7 @@ class Server(object):
 
         self.set_port(port)
         self._prepare_resource(workdir)
-        self._prepare_engine(self.model_config_path, device)
+        self._prepare_engine(self.model_config_paths, device)
         self._prepare_infer_service(port)
         self.workdir = workdir
 
diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py
index 6841220f9f4e52a23bc7b0a0176c58672fc4b675..eb1ecfd8faaf34a6bf2955af46d5a8cf09085ad7 100644
--- a/python/paddle_serving_server_gpu/web_service.py
+++ b/python/paddle_serving_server_gpu/web_service.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# pylint: disable=doc-string-missing
 
 from flask import Flask, request, abort
 from contextlib import closing
diff --git a/python/setup.py.app.in b/python/setup.py.app.in
index 13e71b22cdc5eb719c17af974dd2150710133491..3c0f8e065a5072919d808ba1da67f5c37eee0594 100644
--- a/python/setup.py.app.in
+++ b/python/setup.py.app.in
@@ -47,7 +47,8 @@ REQUIRED_PACKAGES = [
 
 packages=['paddle_serving_app',
           'paddle_serving_app.reader',
-	  'paddle_serving_app.utils']
+	  'paddle_serving_app.utils',
+      'paddle_serving_app.reader.pddet']
 
 package_data={}
 package_dir={'paddle_serving_app':
@@ -55,7 +56,9 @@ package_dir={'paddle_serving_app':
              'paddle_serving_app.reader':
              '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/reader',
 	     'paddle_serving_app.utils':
-	     '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/utils',}
+	     '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/utils',
+          'paddle_serving_app.reader.pddet':
+           '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving_app/reader/pddet',}
 
 setup(
     name='paddle-serving-app',
diff --git a/tools/serving_build.sh b/tools/serving_build.sh
index e4bf6ece3a9df1808b9190e9e77d8d2e8aba62c0..1e47b8f4fe26c689b5d6680c1478740201b335b9 100644
--- a/tools/serving_build.sh
+++ b/tools/serving_build.sh
@@ -323,6 +323,9 @@ function python_test_bert() {
             echo "bert RPC inference pass"
             ;;
         *)
+            echo "error type"
+            exit 1
+            ;;
     esac
     echo "test bert $TYPE finished as expected."
     unset SERVING_BIN
@@ -357,6 +360,9 @@ function python_test_imdb() {
             echo "imdb ignore GPU test"
             ;;
         *)
+            echo "error type"
+            exit 1
+            ;;
     esac
     echo "test imdb $TYPE finished as expected."
     unset SERVING_BIN
@@ -389,6 +395,9 @@ function python_test_lac() {
             echo "lac ignore GPU test"
             ;;
         *)
+            echo "error type"
+            exit 1
+            ;;
     esac
     echo "test lac $TYPE finished as expected."
     unset SERVING_BIN
@@ -408,6 +417,248 @@ function python_run_test() {
     cd ../.. # pwd: /Serving
 }
 
+function monitor_test() {
+    local TYPE=$1 # pwd: /Serving
+    mkdir _monitor_test && cd _monitor_test # pwd: /Serving/_monitor_test
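+    # Each case below starts a local pyftpdlib FTP server, stages the
+    # uci_housing model on the "remote" side, runs the corresponding monitor
+    # module against it, and checks that the local timestamp file is created.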
+    case $TYPE in
+        CPU):
+            pip install pyftpdlib
+            mkdir remote_path
+            mkdir local_path
+            cd remote_path # pwd: /Serving/_monitor_test/remote_path
+            check_cmd "python -m pyftpdlib -p 8000 &>/dev/null &"
+            cd .. # pwd: /Serving/_monitor_test
+
+            # type: ftp
+            # remote_path: /
+            # remote_model_name: uci_housing.tar.gz
+            # local_tmp_path: ___tmp
+            # local_path: local_path
+            cd remote_path # pwd: /Serving/_monitor_test/remote_path
+            wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz
+            touch donefile
+            cd ..  # pwd: /Serving/_monitor_test
+            mkdir -p local_path/uci_housing_model
+            python -m paddle_serving_server.monitor \
+                    --type='ftp' --ftp_host='127.0.0.1' --ftp_port='8000' \
+                    --remote_path='/' --remote_model_name='uci_housing.tar.gz' \
+                    --remote_donefile_name='donefile' --local_path='local_path' \
+                    --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \
+                    --local_tmp_path='___tmp' --unpacked_filename='uci_housing_model' \
+                    --interval='1' >/dev/null &
+            sleep 10
+            if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then
+                echo "local_path/uci_housing_model/fluid_time_file not exist."
+                exit 1
+            fi
+            ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill
+            rm -rf remote_path/*
+            rm -rf local_path/*
+
+            # type: ftp
+            # remote_path: /tmp_dir
+            # remote_model_name: uci_housing_model
+            # local_tmp_path: ___tmp
+            # local_path: local_path
+            mkdir -p remote_path/tmp_dir && cd remote_path/tmp_dir # pwd: /Serving/_monitor_test/remote_path/tmp_dir
+            wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz
+            tar -xzf uci_housing.tar.gz
+            touch donefile
+            cd ../.. # pwd: /Serving/_monitor_test
+            mkdir -p local_path/uci_housing_model
+            python -m paddle_serving_server.monitor \
+                    --type='ftp' --ftp_host='127.0.0.1' --ftp_port='8000' \
+                    --remote_path='/tmp_dir' --remote_model_name='uci_housing_model' \
+                    --remote_donefile_name='donefile' --local_path='local_path' \
+                    --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \
+                    --local_tmp_path='___tmp' --interval='1' >/dev/null &
+            sleep 10
+            if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then
+                echo "local_path/uci_housing_model/fluid_time_file not exist."
+                exit 1
+            fi
+            ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill
+            rm -rf remote_path/*
+            rm -rf local_path/*
+
+            # type: general
+            # remote_path: /
+            # remote_model_name: uci_housing.tar.gz
+            # local_tmp_path: ___tmp
+            # local_path: local_path
+            cd remote_path # pwd: /Serving/_monitor_test/remote_path
+            wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz
+            touch donefile
+            cd ..  # pwd: /Serving/_monitor_test
+            mkdir -p local_path/uci_housing_model
+            python -m paddle_serving_server.monitor \
+                    --type='general' --general_host='ftp://127.0.0.1:8000' \
+                    --remote_path='/' --remote_model_name='uci_housing.tar.gz' \
+                    --remote_donefile_name='donefile' --local_path='local_path' \
+                    --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \
+                    --local_tmp_path='___tmp' --unpacked_filename='uci_housing_model' \
+                    --interval='1' >/dev/null &
+            sleep 10
+            if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then
+                echo "local_path/uci_housing_model/fluid_time_file not exist."
+                exit 1
+            fi
+            ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill
+            rm -rf remote_path/*
+            rm -rf local_path/*
+
+            # type: general
+            # remote_path: /tmp_dir
+            # remote_model_name: uci_housing_model
+            # local_tmp_path: ___tmp
+            # local_path: local_path
+            mkdir -p remote_path/tmp_dir && cd remote_path/tmp_dir # pwd: /Serving/_monitor_test/remote_path/tmp_dir
+            wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz
+            tar -xzf uci_housing.tar.gz
+            touch donefile
+            cd ../.. # pwd: /Serving/_monitor_test
+            mkdir -p local_path/uci_housing_model
+            python -m paddle_serving_server.monitor \
+                    --type='general' --general_host='ftp://127.0.0.1:8000' \
+                    --remote_path='/tmp_dir' --remote_model_name='uci_housing_model' \
+                    --remote_donefile_name='donefile' --local_path='local_path' \
+                    --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \
+                    --local_tmp_path='___tmp' --interval='1' >/dev/null &
+            sleep 10
+            if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then
+                echo "local_path/uci_housing_model/fluid_time_file not exist."
+                exit 1
+            fi
+            ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill
+            rm -rf remote_path/*
+            rm -rf local_path/*
+
+            ps -ef | grep "pyftpdlib" | grep -v grep | awk '{print $2}' | xargs kill
+            ;;
+        GPU):
+            pip install pyftpdlib
+            mkdir remote_path
+            mkdir local_path
+            cd remote_path # pwd: /Serving/_monitor_test/remote_path
+            check_cmd "python -m pyftpdlib -p 8000 &>/dev/null &"
+            cd .. # pwd: /Serving/_monitor_test
+
+            # type: ftp
+            # remote_path: /
+            # remote_model_name: uci_housing.tar.gz
+            # local_tmp_path: ___tmp
+            # local_path: local_path
+            cd remote_path # pwd: /Serving/_monitor_test/remote_path
+            wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz
+            touch donefile
+            cd ..  # pwd: /Serving/_monitor_test
+            mkdir -p local_path/uci_housing_model
+            python -m paddle_serving_server_gpu.monitor \
+                    --type='ftp' --ftp_host='127.0.0.1' --ftp_port='8000' \
+                    --remote_path='/' --remote_model_name='uci_housing.tar.gz' \
+                    --remote_donefile_name='donefile' --local_path='local_path' \
+                    --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \
+                    --local_tmp_path='___tmp' --unpacked_filename='uci_housing_model' \
+                    --interval='1' >/dev/null &
+            sleep 10
+            if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then
+                echo "local_path/uci_housing_model/fluid_time_file not exist."
+                exit 1
+            fi
+            ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill
+            rm -rf remote_path/*
+            rm -rf local_path/*
+
+            # type: ftp
+            # remote_path: /tmp_dir
+            # remote_model_name: uci_housing_model
+            # local_tmp_path: ___tmp
+            # local_path: local_path
+            mkdir -p remote_path/tmp_dir && cd remote_path/tmp_dir # pwd: /Serving/_monitor_test/remote_path/tmp_dir
+            wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz
+            tar -xzf uci_housing.tar.gz
+            touch donefile
+            cd ../.. # pwd: /Serving/_monitor_test
+            mkdir -p local_path/uci_housing_model
+            python -m paddle_serving_server_gpu.monitor \
+                    --type='ftp' --ftp_host='127.0.0.1' --ftp_port='8000' \
+                    --remote_path='/tmp_dir' --remote_model_name='uci_housing_model' \
+                    --remote_donefile_name='donefile' --local_path='local_path' \
+                    --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \
+                    --local_tmp_path='___tmp' --interval='1' >/dev/null &
+            sleep 10
+            if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then
+                echo "local_path/uci_housing_model/fluid_time_file not exist."
+                exit 1
+            fi
+            ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill
+            rm -rf remote_path/*
+            rm -rf local_path/*
+
+            # type: general
+            # remote_path: /
+            # remote_model_name: uci_housing.tar.gz
+            # local_tmp_path: ___tmp
+            # local_path: local_path
+            cd remote_path # pwd: /Serving/_monitor_test/remote_path
+            wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz
+            touch donefile
+            cd ..  # pwd: /Serving/_monitor_test
+            mkdir -p local_path/uci_housing_model
+            python -m paddle_serving_server_gpu.monitor \
+                    --type='general' --general_host='ftp://127.0.0.1:8000' \
+                    --remote_path='/' --remote_model_name='uci_housing.tar.gz' \
+                    --remote_donefile_name='donefile' --local_path='local_path' \
+                    --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \
+                    --local_tmp_path='___tmp' --unpacked_filename='uci_housing_model' \
+                    --interval='1' >/dev/null &
+            sleep 10
+            if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then
+                echo "local_path/uci_housing_model/fluid_time_file not exist."
+                exit 1
+            fi
+            ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill
+            rm -rf remote_path/*
+            rm -rf local_path/*
+
+            # type: general
+            # remote_path: /tmp_dir
+            # remote_model_name: uci_housing_model
+            # local_tmp_path: ___tmp
+            # local_path: local_path
+            mkdir -p remote_path/tmp_dir && cd remote_path/tmp_dir # pwd: /Serving/_monitor_test/remote_path/tmp_dir
+            wget --no-check-certificate https://paddle-serving.bj.bcebos.com/uci_housing.tar.gz
+            tar -xzf uci_housing.tar.gz
+            touch donefile
+            cd ../.. # pwd: /Serving/_monitor_test
+            mkdir -p local_path/uci_housing_model
+            python -m paddle_serving_server_gpu.monitor \
+                    --type='general' --general_host='ftp://127.0.0.1:8000' \
+                    --remote_path='/tmp_dir' --remote_model_name='uci_housing_model' \
+                    --remote_donefile_name='donefile' --local_path='local_path' \
+                    --local_model_name='uci_housing_model' --local_timestamp_file='fluid_time_file' \
+                    --local_tmp_path='___tmp' --interval='1' >/dev/null &
+            sleep 10
+            if [ ! -f "local_path/uci_housing_model/fluid_time_file" ]; then
+                echo "local_path/uci_housing_model/fluid_time_file not exist."
+                exit 1
+            fi
+            ps -ef | grep "monitor" | grep -v grep | awk '{print $2}' | xargs kill
+            rm -rf remote_path/*
+            rm -rf local_path/*
+
+            ps -ef | grep "pyftpdlib" | grep -v grep | awk '{print $2}' | xargs kill
+            ;;
+        *)
+            echo "error type"
+            exit 1
+            ;;
+    esac
+    cd .. # pwd: /Serving
+    rm -rf _monitor_test
+    echo "test monitor $TYPE finished as expected."
+}
+
 function main() {
     local TYPE=$1 # pwd: /
     init # pwd: /Serving
@@ -415,6 +666,7 @@ function main() {
     build_server $TYPE # pwd: /Serving
     build_app $TYPE # pwd: /Serving
     python_run_test $TYPE # pwd: /Serving
+    monitor_test $TYPE # pwd: /Serving
     echo "serving $TYPE part finished as expected."
 }