diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc
index 99c599680a486ee5a0fd03075dbfd8d825abc249..f4fdf4880bcf50ce22bb024e6e8fdd75d36655ba 100644
--- a/paddle/fluid/distributed/fleet.cc
+++ b/paddle/fluid/distributed/fleet.cc
@@ -472,9 +472,15 @@ void FleetWrapper::PrintTableStat(const uint64_t table_id) {
   }
 }
 
-void FleetWrapper::ShrinkSparseTable(int table_id) {
-  auto ret = pserver_ptr_->_worker_ptr->shrink(table_id);
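+// evict ids from the sparse table whose unseen_days reach the given threshold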
+void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) {
+  auto* communicator = Communicator::GetInstance();
+  auto ret =
+      communicator->_worker_ptr->shrink(table_id, std::to_string(threshold));
   ret.wait();
+  int32_t err_code = ret.get();
+  if (err_code == -1) {
+    LOG(ERROR) << "shrink sparse table stat failed";
+  }
 }
 
 void FleetWrapper::ClearModel() {
diff --git a/paddle/fluid/distributed/fleet.h b/paddle/fluid/distributed/fleet.h
index 25c4e3ef8b8e6dd8c00062ef00be2ea1c328e16d..ac566606ddcb4024eeaf7b846c894f7f5cdafa82 100644
--- a/paddle/fluid/distributed/fleet.h
+++ b/paddle/fluid/distributed/fleet.h
@@ -217,7 +217,7 @@ class FleetWrapper {
   // clear one table
   void ClearOneTable(const uint64_t table_id);
   // shrink sparse table
-  void ShrinkSparseTable(int table_id);
+  void ShrinkSparseTable(int table_id, int threshold);
   // shrink dense table
   void ShrinkDenseTable(int table_id, Scope* scope,
                         std::vector<std::string> var_list, float decay,
diff --git a/paddle/fluid/distributed/service/brpc_ps_client.cc b/paddle/fluid/distributed/service/brpc_ps_client.cc
index 39e38c22020e0a902ec538a82f73b19da6fe1db0..163526fe3b28c91f36e2670d1974b520ef3bf66a 100644
--- a/paddle/fluid/distributed/service/brpc_ps_client.cc
+++ b/paddle/fluid/distributed/service/brpc_ps_client.cc
@@ -345,8 +345,9 @@ std::future<int32_t> BrpcPsClient::send_save_cmd(
   return fut;
 }
 
-std::future<int32_t> BrpcPsClient::shrink(uint32_t table_id) {
-  return send_cmd(table_id, PS_SHRINK_TABLE, {std::string("1")});
+std::future<int32_t> BrpcPsClient::shrink(uint32_t table_id,
+                                          const std::string threshold) {
+  return send_cmd(table_id, PS_SHRINK_TABLE, {threshold});
 }
 
 std::future<int32_t> BrpcPsClient::load(const std::string &epoch,
diff --git a/paddle/fluid/distributed/service/brpc_ps_client.h b/paddle/fluid/distributed/service/brpc_ps_client.h
index e4d9e537640f602e65a3dcde925672d4e6755c53..8f9d2653864d1c7fd1801632a6c84edb1bc04ccf 100644
--- a/paddle/fluid/distributed/service/brpc_ps_client.h
+++ b/paddle/fluid/distributed/service/brpc_ps_client.h
@@ -115,7 +115,8 @@ class BrpcPsClient : public PSClient {
   }
   virtual int32_t create_client2client_connection(
       int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
-  virtual std::future<int32_t> shrink(uint32_t table_id) override;
+  virtual std::future<int32_t> shrink(uint32_t table_id,
+                                      const std::string threshold) override;
   virtual std::future<int32_t> load(const std::string &epoch,
                                     const std::string &mode) override;
   virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,
diff --git a/paddle/fluid/distributed/service/brpc_ps_server.cc b/paddle/fluid/distributed/service/brpc_ps_server.cc
index 110397485c52d59d57fe86542bb75da8ef6c2d4c..32de11847387b6573707b330dd042a81c04857ed 100644
--- a/paddle/fluid/distributed/service/brpc_ps_server.cc
+++ b/paddle/fluid/distributed/service/brpc_ps_server.cc
@@ -463,6 +463,8 @@ int32_t BrpcPsService::save_one_table(Table *table,
   table->flush();
 
   int32_t feasign_size = 0;
+
+  VLOG(0) << "save one table " << request.params(0) << " " << request.params(1);
   feasign_size = table->save(request.params(0), request.params(1));
   if (feasign_size < 0) {
     set_response_code(response, -1, "table save failed");
@@ -494,10 +496,18 @@ int32_t BrpcPsService::shrink_table(Table *table,
                                     PsResponseMessage &response,
                                     brpc::Controller *cntl) {
   CHECK_TABLE_EXIST(table, request, response)
+  if (request.params_size() < 1) {
+    set_response_code(
+        response, -1,
+        "PsRequestMessage.datas is requeired at least 1, threshold");
+    return -1;
+  }
   table->flush();
-  if (table->shrink() != 0) {
+  if (table->shrink(request.params(0)) != 0) {
     set_response_code(response, -1, "table shrink failed");
+    return -1;
   }
+  VLOG(0) << "Pserver Shrink Finished";
   return 0;
 }
 
diff --git a/paddle/fluid/distributed/service/ps_client.h b/paddle/fluid/distributed/service/ps_client.h
index 22f560f1224a6bc2e0cbfc77d66d165fe3468393..50f5802c63a2538566988f75e3c098bd01785294 100644
--- a/paddle/fluid/distributed/service/ps_client.h
+++ b/paddle/fluid/distributed/service/ps_client.h
@@ -75,7 +75,8 @@ class PSClient {
       int max_retry) = 0;
 
   // trigger data eviction for the table
-  virtual std::future<int32_t> shrink(uint32_t table_id) = 0;
+  virtual std::future<int32_t> shrink(uint32_t table_id,
+                                      const std::string threshold) = 0;
 
   // load data for the full table
   virtual std::future<int32_t> load(const std::string &epoch,
diff --git a/paddle/fluid/distributed/table/common_dense_table.h b/paddle/fluid/distributed/table/common_dense_table.h
index 4b9f4900b8f003878241c235d4ad9a6772bd5593..e363afc45c54c366eebec4977bbc270b4c650680 100644
--- a/paddle/fluid/distributed/table/common_dense_table.h
+++ b/paddle/fluid/distributed/table/common_dense_table.h
@@ -60,7 +60,7 @@ class CommonDenseTable : public DenseTable {
   }
 
   virtual int32_t flush() override { return 0; }
-  virtual int32_t shrink() override { return 0; }
+  virtual int32_t shrink(const std::string& param) override { return 0; }
   virtual void clear() override { return; }
 
  protected:
diff --git a/paddle/fluid/distributed/table/common_sparse_table.cc b/paddle/fluid/distributed/table/common_sparse_table.cc
index fbfb7280c95501a97f6739078a1d1c1e109a9c7c..e0b331bbde2b2fd93027c3488f8d1ade17e361a6 100644
--- a/paddle/fluid/distributed/table/common_sparse_table.cc
+++ b/paddle/fluid/distributed/table/common_sparse_table.cc
@@ -26,9 +26,12 @@ class ValueBlock;
 }  // namespace paddle
 
 #define PSERVER_SAVE_SUFFIX "_txt"
+
 namespace paddle {
 namespace distributed {
 
+enum SaveMode { all, base, delta };
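+// all: save every id; base/delta: clear need_save_ after saving; delta also
+// skips ids that have not been updated since the last save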
+
 struct Meta {
   std::string param;
   int shard_id;
@@ -98,12 +101,9 @@ struct Meta {
 
 void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
                   std::vector<std::vector<float>>* values) {
-  PADDLE_ENFORCE_EQ(columns.size(), 2,
-                    paddle::platform::errors::InvalidArgument(
-                        "The data format does not meet the requirements. It "
-                        "should look like feasign_id \t params."));
-
-  auto load_values = paddle::string::split_string<std::string>(columns[1], ",");
+  auto column_size = columns.size();
+  auto load_values =
+      paddle::string::split_string<std::string>(columns[column_size - 1], ",");
   values->reserve(meta.names.size());
 
   int offset = 0;
@@ -125,11 +125,18 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
 
 int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
                    const int mode) {
+  int64_t not_save_num = 0;
   for (auto value : block->values_) {
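+    // in delta mode, skip values that have not been updated since the last save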
+    if (mode == SaveMode::delta && !value.second->need_save_) {
+      not_save_num++;
+      continue;
+    }
+
     auto* vs = value.second->data_.data();
     std::stringstream ss;
     auto id = value.first;
-    ss << id << "\t";
+    ss << id << "\t" << value.second->count_ << "\t"
+       << value.second->unseen_days_ << "\t" << value.second->is_entry_ << "\t";
 
     for (int i = 0; i < block->value_length_; i++) {
       ss << vs[i];
@@ -139,9 +146,13 @@ int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
     ss << "\n";
 
     os->write(ss.str().c_str(), sizeof(char) * ss.str().size());
+
+    if (mode == SaveMode::base || mode == SaveMode::delta) {
+      value.second->need_save_ = false;
+    }
   }
 
-  return block->values_.size();
+  return block->values_.size() - not_save_num;
 }
 
 int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
@@ -169,8 +180,21 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
 
     std::vector<std::vector<float>> kvalues;
     ProcessALine(values, meta, &kvalues);
-    // warning: need fix
-    block->Init(id);
+
+    block->Init(id, false);
+
+    auto value_instant = block->GetValue(id);
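+    // new text format: feasign_id \t count \t unseen_days \t is_entry \t params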
+    if (values.size() == 5) {
+      value_instant->count_ = std::stoi(values[1]);
+      value_instant->unseen_days_ = std::stoi(values[2]);
+      value_instant->is_entry_ = static_cast<bool>(std::stoi(values[3]));
+    }
+
+    std::vector<float*> block_values = block->Get(id, meta.names, meta.dims);
+    auto blas = GetBlas<float>();
+    for (int x = 0; x < meta.names.size(); ++x) {
+      blas.VCOPY(meta.dims[x], kvalues[x].data(), block_values[x]);
+    }
   }
 
   return 0;
@@ -397,7 +421,7 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values, const uint64_t* keys,
           for (int i = 0; i < offsets.size(); ++i) {
             auto offset = offsets[i];
             auto id = keys[offset];
-            auto* value = block->InitFromInitializer(id);
+            auto* value = block->Init(id);
             std::copy_n(value + param_offset_, param_dim_,
                         pull_values + param_dim_ * offset);
           }
@@ -492,9 +516,10 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
           for (int i = 0; i < offsets.size(); ++i) {
             auto offset = offsets[i];
             auto id = keys[offset];
-            auto* value = block->InitFromInitializer(id);
+            auto* value = block->Init(id, false);
             std::copy_n(values + param_dim_ * offset, param_dim_,
                         value + param_offset_);
+            block->SetEntry(id, true);
           }
           return 0;
         });
@@ -509,10 +534,20 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
 
 int32_t CommonSparseTable::flush() { return 0; }
 
-int32_t CommonSparseTable::shrink() {
-  VLOG(0) << "shrink coming soon";
+int32_t CommonSparseTable::shrink(const std::string& param) {
+  rwlock_->WRLock();
+  int threshold = std::stoi(param);
+  VLOG(0) << "sparse table shrink: " << threshold;
+
+  for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
+    // shrink
+    VLOG(0) << "shard " << shard_id << " of " << task_pool_size_
+            << " begin shrink";
+    shard_values_[shard_id]->Shrink(threshold);
+  }
+  rwlock_->UNLock();
   return 0;
 }
+
 void CommonSparseTable::clear() { VLOG(0) << "clear coming soon"; }
 
 }  // namespace distributed
diff --git a/paddle/fluid/distributed/table/common_sparse_table.h b/paddle/fluid/distributed/table/common_sparse_table.h
index d8df0f663cfa1cc96f655a9c3ae673cd8ab33f43..98cbf2b4a21057f64a5d510158907df1de393925 100644
--- a/paddle/fluid/distributed/table/common_sparse_table.h
+++ b/paddle/fluid/distributed/table/common_sparse_table.h
@@ -75,7 +75,7 @@ class CommonSparseTable : public SparseTable {
 
   virtual int32_t pour();
   virtual int32_t flush();
-  virtual int32_t shrink();
+  virtual int32_t shrink(const std::string& param);
   virtual void clear();
 
  protected:
diff --git a/paddle/fluid/distributed/table/common_table.h b/paddle/fluid/distributed/table/common_table.h
index d37e6677e634d7da93b661cf02389c2f4abee19a..034769e021207cce0277fad6739566bf8da1fa67 100644
--- a/paddle/fluid/distributed/table/common_table.h
+++ b/paddle/fluid/distributed/table/common_table.h
@@ -108,7 +108,7 @@ class DenseTable : public Table {
   int32_t push_dense_param(const float *values, size_t num) override {
     return 0;
   }
-  int32_t shrink() override { return 0; }
+  int32_t shrink(const std::string &param) override { return 0; }
 };
 
 class BarrierTable : public Table {
@@ -133,7 +133,7 @@ class BarrierTable : public Table {
   int32_t push_dense_param(const float *values, size_t num) override {
     return 0;
   }
-  int32_t shrink() override { return 0; }
+  int32_t shrink(const std::string &param) override { return 0; }
   virtual void clear(){};
   virtual int32_t flush() { return 0; };
   virtual int32_t load(const std::string &path, const std::string &param) {
diff --git a/paddle/fluid/distributed/table/depends/large_scale_kv.h b/paddle/fluid/distributed/table/depends/large_scale_kv.h
index 55f8489b08cba04a132bba81c72ac34cf28a8ce2..1cfbf2a5ffd2cecd96c34f9597d601a7f3be42bd 100644
--- a/paddle/fluid/distributed/table/depends/large_scale_kv.h
+++ b/paddle/fluid/distributed/table/depends/large_scale_kv.h
@@ -47,43 +47,34 @@ namespace distributed {
 
 enum Mode { training, infer };
 
-template <typename T>
-inline bool entry(const int count, const T threshold);
-
-template <>
-inline bool entry<std::string>(const int count, const std::string threshold) {
-  return true;
-}
-
-template <>
-inline bool entry<int>(const int count, const int threshold) {
-  return count >= threshold;
-}
-
-template <>
-inline bool entry<float>(const int count, const float threshold) {
-  UniformInitializer uniform = UniformInitializer({"0", "0", "1"});
-  return uniform.GetValue() >= threshold;
-}
-
 struct VALUE {
   explicit VALUE(size_t length)
       : length_(length),
-        count_(1),
+        count_(0),
         unseen_days_(0),
-        seen_after_last_save_(true),
-        is_entry_(true) {
+        need_save_(false),
+        is_entry_(false) {
     data_.resize(length);
+    memset(data_.data(), 0, sizeof(float) * length);
   }
 
   size_t length_;
   std::vector<float> data_;
   int count_;
-  int unseen_days_;
-  bool seen_after_last_save_;
-  bool is_entry_;
+  int unseen_days_;  // used to decide knock-out (eviction) in Shrink
+  bool need_save_;   // whether the value was updated since the last save
+  bool is_entry_;    // whether the id has been admitted (knock-in)
 };
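+
+// Entry (knock-in) policies: count_entry admits an id once its update count
+// reaches `threshold`; probability_entry admits it when a uniform random draw
+// is at least `threshold`.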
 
+inline bool count_entry(std::shared_ptr<VALUE> value, int threshold) {
+  return value->count_ >= threshold;
+}
+
+inline bool probability_entry(std::shared_ptr<VALUE> value, float threshold) {
+  UniformInitializer uniform = UniformInitializer({"0", "0", "1"});
+  return uniform.GetValue() >= threshold;
+}
+
 class ValueBlock {
  public:
   explicit ValueBlock(const std::vector<std::string> &value_names,
@@ -102,21 +93,21 @@ class ValueBlock {
 
     // for Entry
     {
-      if (entry_attr == "none") {
-        has_entry_ = false;
+      auto slices = string::split_string<std::string>(entry_attr, "&");
+      if (slices[0] == "none") {
+        entry_func_ = std::bind(&count_entry, std::placeholders::_1, 0);
+      } else if (slices[0] == "count_filter") {
+        int threshold = std::stoi(slices[1]);
+        entry_func_ = std::bind(&count_entry, std::placeholders::_1, threshold);
+      } else if (slices[0] == "probability") {
+        float threshold = std::stof(slices[1]);
         entry_func_ =
-            std::bind(entry<std::string>, std::placeholders::_1, "none");
+            std::bind(&probability_entry, std::placeholders::_1, threshold);
       } else {
-        has_entry_ = true;
-        auto slices = string::split_string<std::string>(entry_attr, "&");
-        if (slices[0] == "count_filter") {
-          int threshold = std::stoi(slices[1]);
-          entry_func_ = std::bind(entry<int>, std::placeholders::_1, threshold);
-        } else if (slices[0] == "probability") {
-          float threshold = std::stof(slices[1]);
-          entry_func_ =
-              std::bind(entry<float>, std::placeholders::_1, threshold);
-        }
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Not supported Entry Type : %s, Only support [count_filter, "
+            "probability]",
+            slices[0]));
       }
     }
 
@@ -147,58 +138,87 @@ class ValueBlock {
 
   ~ValueBlock() {}
 
-  float *Init(const uint64_t &id) {
-    auto value = std::make_shared<VALUE>(value_length_);
-    for (int x = 0; x < value_names_.size(); ++x) {
-      initializers_[x]->GetValue(value->data_.data() + value_offsets_[x],
-                                 value_dims_[x]);
-    }
-    values_[id] = value;
-    return value->data_.data();
-  }
-
   std::vector<float *> Get(const uint64_t &id,
-                           const std::vector<std::string> &value_names) {
+                           const std::vector<std::string> &value_names,
+                           const std::vector<int> &value_dims) {
     auto pts = std::vector<float *>();
     pts.reserve(value_names.size());
     auto &values = values_.at(id);
     for (int i = 0; i < static_cast<int>(value_names.size()); i++) {
+      PADDLE_ENFORCE_EQ(
+          value_dims[i], value_dims_[i],
+          platform::errors::InvalidArgument("value dims is not match"));
       pts.push_back(values->data_.data() +
                     value_offsets_.at(value_idx_.at(value_names[i])));
     }
     return pts;
   }
 
-  float *Get(const uint64_t &id) {
-    auto pts = std::vector<std::vector<float> *>();
-    auto &values = values_.at(id);
+  // pull: get or create the value for an id; entry attributes are updated
+  // unless with_update is false
+  float *Init(const uint64_t &id, const bool with_update = true) {
+    if (!Has(id)) {
+      values_[id] = std::make_shared<VALUE>(value_length_);
+    }
+
+    auto &value = values_.at(id);
 
-    return values->data_.data();
+    if (with_update) {
+      AttrUpdate(value);
+    }
+
+    return value->data_.data();
   }
 
-  float *InitFromInitializer(const uint64_t &id) {
-    if (Has(id)) {
-      if (has_entry_) {
-        Update(id);
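+  // called on each update: resets unseen_days_, bumps count_, initializes the
+  // value once the entry function admits the id, and marks it for saving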
+  void AttrUpdate(std::shared_ptr<VALUE> value) {
+    // update state
+    value->unseen_days_ = 0;
+    ++value->count_;
+
+    if (!value->is_entry_) {
+      value->is_entry_ = entry_func_(value);
+      if (value->is_entry_) {
+        // initialize
+        for (int x = 0; x < value_names_.size(); ++x) {
+          initializers_[x]->GetValue(value->data_.data() + value_offsets_[x],
+                                     value_dims_[x]);
+        }
       }
-      return Get(id);
     }
-    return Init(id);
+
+    value->need_save_ = true;
+    return;
   }
 
+  // does not check Has(id); the caller must ensure the id exists
+  float *Get(const uint64_t &id) {
+    auto &value = values_.at(id);
+    return value->data_.data();
+  }
+
+  // for load: restore count_, unseen_days_ and is_entry_
+  std::shared_ptr<VALUE> GetValue(const uint64_t &id) { return values_.at(id); }
+
   bool GetEntry(const uint64_t &id) {
-    auto value = values_.at(id);
+    auto &value = values_.at(id);
     return value->is_entry_;
   }
 
-  void Update(const uint64_t id) {
-    auto value = values_.at(id);
-    value->unseen_days_ = 0;
-    auto count = ++value->count_;
+  void SetEntry(const uint64_t &id, const bool state) {
+    auto &value = values_.at(id);
+    value->is_entry_ = state;
+  }
 
-    if (!value->is_entry_) {
-      value->is_entry_ = entry_func_(count);
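+  // evict ids that have not been updated for `threshold` consecutive Shrink calls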
+  void Shrink(const int threshold) {
+    for (auto iter = values_.begin(); iter != values_.end();) {
+      auto &value = iter->second;
+      value->unseen_days_++;
+      if (value->unseen_days_ >= threshold) {
+        iter = values_.erase(iter);
+      } else {
+        ++iter;
+      }
     }
+    return;
   }
 
  private:
@@ -221,8 +241,7 @@ class ValueBlock {
   const std::vector<int> &value_offsets_;
   const std::unordered_map<std::string, int> &value_idx_;
 
-  bool has_entry_ = false;
-  std::function<bool(uint64_t)> entry_func_;
+  std::function<bool(std::shared_ptr<VALUE>)> entry_func_;
   std::vector<std::shared_ptr<Initializer>> initializers_;
 };
 
diff --git a/paddle/fluid/distributed/table/depends/sparse.h b/paddle/fluid/distributed/table/depends/sparse.h
index 4ee753fc75a3f6f39d2f4bafe3310735922a52ac..672d6e7d396874b5cc5a296f15e3842a3233410b 100644
--- a/paddle/fluid/distributed/table/depends/sparse.h
+++ b/paddle/fluid/distributed/table/depends/sparse.h
@@ -76,6 +76,7 @@ class SSUM : public SparseOptimizer {
     auto blas = GetBlas<float>();
     for (auto x : offsets) {
       auto id = keys[x];
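+      // skip ids that have not yet been admitted by the entry function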
+      if (!block->GetEntry(id)) continue;
       auto* value = block->Get(id);
       float* param = value + param_offset;
       blas.VADD(update_numel, update_values + x * update_numel, param, param);
@@ -105,6 +106,7 @@ class SSGD : public SparseOptimizer {
     auto blas = GetBlas<float>();
     for (auto x : offsets) {
       auto id = keys[x];
+      if (!block->GetEntry(id)) continue;
       auto* value = block->Get(id);
 
       float learning_rate = *(global_learning_rate_) * (value + lr_offset)[0];
@@ -161,6 +163,7 @@ class SAdam : public SparseOptimizer {
     auto blas = GetBlas<float>();
     for (auto x : offsets) {
       auto id = keys[x];
+      if (!block->GetEntry(id)) continue;
       auto* values = block->Get(id);
       float lr_ = *(global_learning_rate_) * (values + lr_offset)[0];
       VLOG(4) << "SAdam LearningRate: " << lr_;
diff --git a/paddle/fluid/distributed/table/table.h b/paddle/fluid/distributed/table/table.h
index 1bfedb53ab83d331d32b3ce828b0c1493c0ccc33..65c99d2bbd40d4567f49eb84bd84173a0a3fee0b 100644
--- a/paddle/fluid/distributed/table/table.h
+++ b/paddle/fluid/distributed/table/table.h
@@ -90,7 +90,7 @@ class Table {
 
   virtual void clear() = 0;
   virtual int32_t flush() = 0;
-  virtual int32_t shrink() = 0;
+  virtual int32_t shrink(const std::string &param) = 0;
 
   // specify the load path
   virtual int32_t load(const std::string &path,
diff --git a/paddle/fluid/distributed/table/tensor_table.h b/paddle/fluid/distributed/table/tensor_table.h
index a57a49d9bd70e45610d05445dffccf3cfaa56fc9..1a8f1a9cd9adb841c3ed1fcf849a3a293c47cc52 100644
--- a/paddle/fluid/distributed/table/tensor_table.h
+++ b/paddle/fluid/distributed/table/tensor_table.h
@@ -60,7 +60,7 @@ class TensorTable : public Table {
                       size_t num) override {
     return 0;
   }
-  int32_t shrink() override { return 0; }
+  int32_t shrink(const std::string &param) override { return 0; }
 
   virtual void *get_shard(size_t shard_idx) { return 0; }
 
@@ -110,7 +110,7 @@ class DenseTensorTable : public TensorTable {
                       size_t num) override {
     return 0;
   }
-  int32_t shrink() override { return 0; }
+  int32_t shrink(const std::string &param) override { return 0; }
 
   virtual void *get_shard(size_t shard_idx) { return 0; }
 
@@ -166,7 +166,7 @@ class GlobalStepTable : public DenseTensorTable {
                       size_t num) override {
     return 0;
   }
-  int32_t shrink() override { return 0; }
+  int32_t shrink(const std::string &param) override { return 0; }
 
   virtual void *get_shard(size_t shard_idx) { return 0; }
 
diff --git a/paddle/fluid/pybind/fleet_py.cc b/paddle/fluid/pybind/fleet_py.cc
index 4777951d82c5e635e40fc2784d32721718067df5..ba716fb3b550ac0bc0bd09b362248de5904edc7a 100644
--- a/paddle/fluid/pybind/fleet_py.cc
+++ b/paddle/fluid/pybind/fleet_py.cc
@@ -62,7 +62,8 @@ void BindDistFleetWrapper(py::module* m) {
       .def("sparse_table_stat", &FleetWrapper::PrintTableStat)
       .def("stop_server", &FleetWrapper::StopServer)
       .def("stop_worker", &FleetWrapper::FinalizeWorker)
-      .def("barrier", &FleetWrapper::BarrierWithTable);
+      .def("barrier", &FleetWrapper::BarrierWithTable)
+      .def("shrink_sparse_table", &FleetWrapper::ShrinkSparseTable);
 }
 
 void BindPSHost(py::module* m) {
diff --git a/python/paddle/distributed/fleet/__init__.py b/python/paddle/distributed/fleet/__init__.py
index 0b7e8da101bba5e6713ad99d1f517b3eca33d755..bd8492ecfa7ee75307b9ef70446271209a2ffb69 100644
--- a/python/paddle/distributed/fleet/__init__.py
+++ b/python/paddle/distributed/fleet/__init__.py
@@ -63,3 +63,4 @@ set_lr = fleet.set_lr
 get_lr = fleet.get_lr
 state_dict = fleet.state_dict
 set_state_dict = fleet.set_state_dict
+shrink = fleet.shrink
diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index f4d62b9bf1be09b0418838c3284ee080ab833a83..f4075e92c4c4488ae2ffb8a2c14b34eeee35123b 100644
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -521,7 +521,8 @@ class Fleet(object):
                              feeded_var_names,
                              target_vars,
                              main_program=None,
-                             export_for_deployment=True):
+                             export_for_deployment=True,
+                             mode=0):
         """
         save inference model for inference.
 
@@ -544,7 +545,7 @@ class Fleet(object):
 
         self._runtime_handle._save_inference_model(
             executor, dirname, feeded_var_names, target_vars, main_program,
-            export_for_deployment)
+            export_for_deployment, mode)
 
     def save_persistables(self, executor, dirname, main_program=None, mode=0):
         """
@@ -591,6 +592,9 @@ class Fleet(object):
         self._runtime_handle._save_persistables(executor, dirname, main_program,
                                                 mode)
 
+    def shrink(self, threshold):
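+        """
+        shrink the sparse tables: ids whose unseen_days reach ``threshold``
+        are evicted from the parameter servers.
+
+        Examples (a minimal sketch):
+
+            .. code-block:: python
+
+                import paddle.distributed.fleet as fleet
+
+                fleet.init()
+                # ... build the program and run training ...
+                fleet.shrink(10)
+
+        """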
+        self._runtime_handle._shrink(threshold)
+
     def distributed_optimizer(self, optimizer, strategy=None):
         """
         Optimizer for distributed training.
diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py
index dc78e1ce485e0e7e662ac79f68a11de085e994f2..91a70bd3f39561909fa918d9fb2b25248ec7dbfb 100644
--- a/python/paddle/distributed/fleet/runtime/the_one_ps.py
+++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py
@@ -946,7 +946,8 @@ class TheOnePSRuntime(RuntimeBase):
                                            feeded_var_names,
                                            target_vars,
                                            main_program=None,
-                                           export_for_deployment=True):
+                                           export_for_deployment=True,
+                                           mode=0):
         """
         Prune the given `main_program` to build a new program especially for inference,
         and then save it and all related parameters to given `dirname` by the `executor`.
@@ -983,10 +984,25 @@ class TheOnePSRuntime(RuntimeBase):
 
             program = Program.parse_from_string(program_desc_str)
             program._copy_dist_param_info_from(fluid.default_main_program())
-            self._ps_inference_save_persistables(executor, dirname, program)
+            self._ps_inference_save_persistables(executor, dirname, program,
+                                                 mode)
 
     def _save_inference_model(self, *args, **kwargs):
         self._ps_inference_save_inference_model(*args, **kwargs)
 
     def _save_persistables(self, *args, **kwargs):
         self._ps_inference_save_persistables(*args, **kwargs)
+
+    def _shrink(self, threshold):
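+        # only the first worker issues the shrink command; all workers barrier before and after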
+        import paddle.distributed.fleet as fleet
+        fleet.util.barrier()
+        if self.role_maker._is_first_worker():
+            sparses = self.compiled_strategy.get_the_one_recv_context(
+                is_dense=False,
+                split_dense_table=self.role_maker.
+                _is_heter_parameter_server_mode,
+                use_origin_program=True)
+
+            for id, names in sparses.items():
+                self._worker.shrink_sparse_table(id, threshold)
+        fleet.util.barrier()
diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps3.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps3.py
index d1740f9d96f515d00697ad99f8d9077c743f3c17..aa7975d2b8bef2bc645567d7c3594fa086996919 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ps3.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ps3.py
@@ -65,7 +65,7 @@ class TestPSPassWithBow(unittest.TestCase):
             return avg_cost
 
         is_distributed = False
-        is_sparse = True
+        is_sparse = False
 
         # query
         q = fluid.layers.data(
@@ -162,7 +162,7 @@ class TestPSPassWithBow(unittest.TestCase):
 
         role = fleet.UserDefinedRoleMaker(
             current_id=0,
-            role=role_maker.Role.SERVER,
+            role=role_maker.Role.WORKER,
             worker_num=2,
             server_endpoints=endpoints)
 
@@ -172,11 +172,13 @@ class TestPSPassWithBow(unittest.TestCase):
 
         strategy = paddle.distributed.fleet.DistributedStrategy()
         strategy.a_sync = True
-        strategy.a_sync_configs = {"k_steps": 100}
+        strategy.a_sync_configs = {"launch_barrier": False}
 
         optimizer = fleet.distributed_optimizer(optimizer, strategy)
         optimizer.minimize(loss)
 
+        fleet.shrink(10)
+
 
 if __name__ == '__main__':
     unittest.main()