diff --git a/python/examples/imdb/get_data.sh b/python/examples/imdb/get_data.sh
new file mode 100644
index 0000000000000000000000000000000000000000..b8931380151aaf12bc5cd5bdb7608a18e4c6ed50
--- /dev/null
+++ b/python/examples/imdb/get_data.sh
@@ -0,0 +1,2 @@
+wget https://fleet.bj.bcebos.com/text_classification_data.tar.gz
+tar -zxvf text_classification_data.tar.gz
diff --git a/python/examples/imdb/imdb_reader.py b/python/examples/imdb/imdb_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..def7ce2197bfd24bc4f17f97e5e4a1aa541bcabc
--- /dev/null
+++ b/python/examples/imdb/imdb_reader.py
@@ -0,0 +1,70 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+import paddle
+import re
+import paddle.fluid.incubate.data_generator as dg
+
+class IMDBDataset(dg.MultiSlotDataGenerator):
+    def load_resource(self, dictfile):
+        self._vocab = {}
+        wid = 0
+        with open(dictfile) as f:
+            for line in f:
+                self._vocab[line.strip()] = wid
+                wid += 1
+        self._unk_id = len(self._vocab)
+        self._pattern = re.compile(r'(;|,|\.|\?|!|\s|\(|\))')
+        self.return_value = ("words", [1, 2, 3, 4, 5, 6]), ("label", [0])
+
+    def get_words_and_label(self, line):
+        send = '|'.join(line.split('|')[:-1]).lower().replace("\n",
+                                                              " ").strip()
+        label = [int(line.split('|')[-1])]
+
+        words = [x for x in self._pattern.split(send) if x and x != " "]
+        feas = [
+            self._vocab[x] if x in self._vocab else self._unk_id for x in words
+        ]
+        return feas, label
+
+    def infer_reader(self, infer_filelist, batch, buf_size):
+        def local_iter():
+            for fname in infer_filelist:
+                with open(fname, "r") as fin:
+                    for line in fin:
+                        feas, label = self.get_words_and_label(line)
+                        yield feas, label
+        import paddle
+        batch_iter = paddle.batch(
+            paddle.reader.shuffle(local_iter, buf_size=buf_size),
+            batch_size=batch)
+        return batch_iter
+
+    def generate_sample(self, line):
+        def memory_iter():
+            for i in range(1000):
+                yield self.return_value
+        def data_iter():
+            feas, label = self.get_words_and_label(line)
+            yield ("words", feas), ("label", label)
+        return data_iter
+
+if __name__ == "__main__":
+    imdb = IMDBDataset()
+    imdb.load_resource("imdb.vocab")
+    imdb.run_from_stdin()
+
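A minimal sketch (not part of the patch) of exercising IMDBDataset directly, useful for sanity-checking the vocabulary and tokenization before wiring the reader into a Dataset pipe command. The paths "imdb.vocab" and "test_data/part-0" are assumptions about the layout of the extracted archive.

    from imdb_reader import IMDBDataset

    imdb = IMDBDataset()
    imdb.load_resource("imdb.vocab")  # assumed vocabulary path

    # infer_reader returns a reader creator; calling it yields batches of
    # (feature_ids, label) samples.
    batch_iter = imdb.infer_reader(["test_data/part-0"], batch=4, buf_size=100)
    for batch in batch_iter():
        for feas, label in batch:
            print(len(feas), label)
        break  # inspect only the first batch
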
diff --git a/python/examples/imdb/inference.conf b/python/examples/imdb/inference.conf
deleted file mode 100644
index bbb056b8914e457e7efcbaa59be5e994f87a9ca1..0000000000000000000000000000000000000000
--- a/python/examples/imdb/inference.conf
+++ /dev/null
@@ -1,6 +0,0 @@
-2 3
-words 1 -1 0
-label 1 1 0
-cost mean_0.tmp_0
-acc accuracy_0.tmp_0
-prediction fc_1.tmp_2
diff --git a/python/examples/imdb/local_train.py b/python/examples/imdb/local_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..52d6dfa2c26aaacbb12d197879ce69c701f82a9e
--- /dev/null
+++ b/python/examples/imdb/local_train.py
@@ -0,0 +1,68 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+import paddle
+import logging
+import paddle.fluid as fluid
+import paddle_serving as serving
+
+logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger("fluid")
+logger.setLevel(logging.INFO)
+
+def load_vocab(filename):
+    vocab = {}
+    with open(filename) as f:
+        wid = 0
+        for line in f:
+            vocab[line.strip()] = wid
+            wid += 1
+    vocab[""] = len(vocab)
+    return vocab
+
+if __name__ == "__main__":
+    vocab = load_vocab('imdb.vocab')
+    dict_dim = len(vocab)
+
+    data = fluid.layers.data(name="words", shape=[1], dtype="int64", lod_level=1)
+    label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+
+    dataset = fluid.DatasetFactory().create_dataset()
+    filelist = ["train_data/%s" % x for x in os.listdir("train_data")]
+    dataset.set_use_var([data, label])
+    pipe_command = "python imdb_reader.py"
+    dataset.set_pipe_command(pipe_command)
+    dataset.set_batch_size(4)
+    dataset.set_filelist(filelist)
+    dataset.set_thread(10)
+    from nets import cnn_net
+    avg_cost, acc, prediction = cnn_net(data, label, dict_dim)
+    optimizer = fluid.optimizer.SGD(learning_rate=0.01)
+    optimizer.minimize(avg_cost)
+
+    exe = fluid.Executor(fluid.CPUPlace())
+    exe.run(fluid.default_startup_program())
+    epochs = 30
+    save_dirname = "cnn_model"
+
+    for i in range(epochs):
+        exe.train_from_dataset(program=fluid.default_main_program(),
+                               dataset=dataset, debug=False)
+        logger.info("TRAIN --> pass: {}".format(i))
+        fluid.io.save_inference_model("%s/epoch%d.model" % (save_dirname, i),
+                                      [data.name, label.name], [acc], exe)
+        serving.save_model("%s/epoch%d.model" % (save_dirname, i), "client_config{}".format(i),
+                           {"words": data, "label": label},
+                           {"acc": acc, "cost": avg_cost, "prediction": prediction})
diff --git a/python/examples/imdb/nets.py b/python/examples/imdb/nets.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b451d16bca7aca1b505a464fe85bcf0568964c6
--- /dev/null
+++ b/python/examples/imdb/nets.py
@@ -0,0 +1,125 @@
+import sys
+import time
+import numpy as np
+
+import paddle
+import paddle.fluid as fluid
+
+
+def bow_net(data,
+            label,
+            dict_dim,
+            emb_dim=128,
+            hid_dim=128,
+            hid_dim2=96,
+            class_dim=2):
+    """
+    bow net
+    """
+    emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim], is_sparse=True)
+    bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
+    bow_tanh = fluid.layers.tanh(bow)
+    fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
+    fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh")
+    prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax")
+    cost = fluid.layers.cross_entropy(input=prediction, label=label)
+    avg_cost = fluid.layers.mean(x=cost)
+    acc = fluid.layers.accuracy(input=prediction, label=label)
+
+    return avg_cost, acc, prediction
+
+
+def cnn_net(data,
+            label,
+            dict_dim,
+            emb_dim=128,
+            hid_dim=128,
+            hid_dim2=96,
+            class_dim=2,
+            win_size=3):
+    """
+    conv net
+    """
+    emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim], is_sparse=True)
+
+    conv_3 = fluid.nets.sequence_conv_pool(
+        input=emb,
+        num_filters=hid_dim,
+        filter_size=win_size,
+        act="tanh",
+        pool_type="max")
+
+    fc_1 = fluid.layers.fc(input=[conv_3], size=hid_dim2)
+
+    prediction = fluid.layers.fc(input=[fc_1], size=class_dim, act="softmax")
+    cost = fluid.layers.cross_entropy(input=prediction, label=label)
+    avg_cost = fluid.layers.mean(x=cost)
+    acc = fluid.layers.accuracy(input=prediction, label=label)
+
+    return avg_cost, acc, prediction
+
+
+def lstm_net(data,
+             label,
+             dict_dim,
+             emb_dim=128,
+             hid_dim=128,
+             hid_dim2=96,
+             class_dim=2,
+             emb_lr=30.0):
+    """
+    lstm net
+    """
+    emb = fluid.layers.embedding(
+        input=data,
+        size=[dict_dim, emb_dim],
+        param_attr=fluid.ParamAttr(learning_rate=emb_lr),
+        is_sparse=True)
+
+    fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
+
+    lstm_h, c = fluid.layers.dynamic_lstm(
+        input=fc0, size=hid_dim * 4, is_reverse=False)
+
+    lstm_max = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
+    lstm_max_tanh = fluid.layers.tanh(lstm_max)
+
+    fc1 = fluid.layers.fc(input=lstm_max_tanh, size=hid_dim2, act='tanh')
+
+    prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
+
+    cost = fluid.layers.cross_entropy(input=prediction, label=label)
+    avg_cost = fluid.layers.mean(x=cost)
+    acc = fluid.layers.accuracy(input=prediction, label=label)
+
+    return avg_cost, acc, prediction
+
+
+def gru_net(data,
+            label,
+            dict_dim,
+            emb_dim=128,
+            hid_dim=128,
+            hid_dim2=96,
+            class_dim=2,
+            emb_lr=400.0):
+    """
+    gru net
+    """
+    emb = fluid.layers.embedding(
+        input=data,
+        size=[dict_dim, emb_dim],
+        param_attr=fluid.ParamAttr(learning_rate=emb_lr))
+
+    fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3)
+    gru_h = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False)
+    gru_max = fluid.layers.sequence_pool(input=gru_h, pool_type='max')
+    gru_max_tanh = fluid.layers.tanh(gru_max)
+    fc1 = fluid.layers.fc(input=gru_max_tanh, size=hid_dim2, act='tanh')
+    prediction = fluid.layers.fc(input=fc1, size=class_dim, act='softmax')
+
+    cost = fluid.layers.cross_entropy(input=prediction, label=label)
+    avg_cost = fluid.layers.mean(x=cost)
+    acc = fluid.layers.accuracy(input=prediction, label=label)
+
+    return avg_cost, acc, prediction
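All four network definitions share the signature (data, label, dict_dim, ...) and return (avg_cost, acc, prediction), so they are drop-in replacements for cnn_net in local_train.py, e.g.:

    # Sketch: swap the LSTM model into local_train.py.
    from nets import lstm_net

    avg_cost, acc, prediction = lstm_net(data, label, dict_dim)
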
diff --git a/python/paddle_serving/__init__.py b/python/paddle_serving/__init__.py
index 1b1ac92198b39c3d00c50af2e68baf68ca9ff75e..f8cab8c47308f39fe3963883e3090f33080ebd83 100644
--- a/python/paddle_serving/__init__.py
+++ b/python/paddle_serving/__init__.py
@@ -12,3 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from .serving_client import Client
+from .io import save_model
diff --git a/python/paddle_serving/io/__init__.py b/python/paddle_serving/io/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e38fdabad3b56f16bb98b98241884b963704fcec
--- /dev/null
+++ b/python/paddle_serving/io/__init__.py
@@ -0,0 +1,64 @@
+#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from paddle.fluid import Executor
+from paddle.fluid.compiler import CompiledProgram
+from paddle.fluid.framework import Program
+from paddle.fluid.framework import default_main_program
+from paddle.fluid import CPUPlace
+from paddle.fluid.io import save_persistables
+import os
+
+def save_model(server_model_folder,
+               client_config_folder,
+               feed_var_dict,
+               fetch_var_dict,
+               main_program=None):
+    if main_program is None:
+        main_program = default_main_program()
+    elif isinstance(main_program, CompiledProgram):
+        main_program = main_program._program
+        if main_program is None:
+            raise TypeError("program should be as Program type or None")
+    if not isinstance(main_program, Program):
+        raise TypeError("program should be as Program type or None")
+
+    executor = Executor(place=CPUPlace())
+
+    save_persistables(executor, server_model_folder,
+                      main_program)
+
+    cmd = "mkdir -p {}".format(client_config_folder)
+    os.system(cmd)
+    with open("{}/client.conf".format(client_config_folder), "w") as fout:
+        fout.write("{} {}\n".format(len(feed_var_dict), len(fetch_var_dict)))
+        for key in feed_var_dict:
+            fout.write("{}".format(key))
+            if feed_var_dict[key].lod_level == 1:
+                fout.write(" 1 -1\n")
+            elif feed_var_dict[key].lod_level == 0:
+                fout.write(" {}".format(len(feed_var_dict[key].shape)))
+                for dim in feed_var_dict[key].shape:
+                    fout.write(" {}".format(dim))
+                fout.write("\n")
+        for key in fetch_var_dict:
+            fout.write("{} {}\n".format(key, fetch_var_dict[key].name))
+
+    cmd = "cp {}/client.conf {}/server.conf".format(
+        client_config_folder, server_model_folder)
+    os.system(cmd)
+
+
+
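For the IMDB example above (feed variables "words" and "label", fetch variables "acc", "cost" and "prediction"), the generated client.conf would look roughly like the sample below. The fetch variable names and the order of the fetch lines depend on graph construction and dict iteration order, and the "label" dims reflect the -1 batch dimension that fluid.layers.data prepends, so treat this as illustrative only.

    2 3
    words 1 -1
    label 2 -1 1
    acc accuracy_0.tmp_0
    cost mean_0.tmp_0
    prediction fc_1.tmp_2
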
diff --git a/python/setup.py.in b/python/setup.py.in
index 313f78092b36ac49b403a7df8dc0ef5dfe5d8052..90d2fcd59ff83c3f16f01830f690b83277e2c326 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -34,12 +34,15 @@ REQUIRED_PACKAGES = [
 
 packages=['paddle_serving',
           'paddle_serving.serving_client',
-          'paddle_serving.proto']
+          'paddle_serving.proto',
+          'paddle_serving.io']
 package_data={'paddle_serving.serving_client': ['serving_client.so']}
 package_dir={'paddle_serving.serving_client':
              '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving/serving_client',
              'paddle_serving.proto':
-             '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving/proto'}
+             '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving/proto',
+             'paddle_serving.io':
+             '${PADDLE_SERVING_BINARY_DIR}/python/paddle_serving/io'}
 
 setup(
     name='paddle-serving-client',