diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.cc b/paddle/fluid/operators/distributed/parameter_prefetch.cc
index a96dec10866c012ed903b956747638848b63e23f..c63d65348880ebb4085d83059d9fead6456216d7 100644
--- a/paddle/fluid/operators/distributed/parameter_prefetch.cc
+++ b/paddle/fluid/operators/distributed/parameter_prefetch.cc
@@ -32,7 +32,7 @@ namespace paddle {
 namespace operators {
 namespace distributed {
 
-using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
 using LoDTensor = framework::LoDTensor;
 using SelectedRows = framework::SelectedRows;
 using DDim = framework::DDim;
@@ -117,6 +117,12 @@ static void MergeMultipleVarsIntoOneBySection(
   auto& id_tensor = scope->FindVar(id_name)->Get<framework::LoDTensor>();
   auto* out_tensor =
       scope->FindVar(out_name)->GetMutable<framework::LoDTensor>();
+
+  PADDLE_ENFORCE_GT(
+      out_tensor->numel(), 0,
+      "When calling this method, the LoDTensor's numel must larger than zero. "
+      "Please check LoDTensor::Resize has been called first.");
+
   auto* out_tensor_data = out_tensor->mutable_data<float>(id_tensor.place());
 
   bool is_on_cpu_place = true;
@@ -138,7 +144,7 @@ static void MergeMultipleVarsIntoOneBySection(
 
       auto row_numel = dims[1];
 
-      for (size_t i = 0; i < dims[0]; ++i) {
+      for (int64_t i = 0; i < dims[0]; ++i) {
         auto id = ids_in_this_section[i];
         auto origin_id = id + abs_sections[section_idx];
         auto& offsets = id_to_offset[origin_id];
@@ -172,8 +178,9 @@ void prefetch(const std::string& id_name, const std::string& out_name,
               const std::vector<std::string>& table_names,
               const std::vector<std::string>& epmap,
               const std::vector<int>& height_sections,
-              const framework::ExecutionContext& context) {
-  auto& local_scope = context.scope().NewScope();
+              const framework::ExecutionContext& context,
+              const framework::Scope& scope) {
+  auto& local_scope = scope.NewScope();
 
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto& cpu_ctx = *pool.Get(platform::CPUPlace());
@@ -190,11 +197,11 @@ void prefetch(const std::string& id_name, const std::string& out_name,
     out_var_names.push_back(out_name + "@" + epmap[i]);
   }
 
-  auto& id_tensor = local_scope.FindVar(id_name)->Get<framework::LoDTensor>();
+  auto& id_tensor = scope.FindVar(id_name)->Get<framework::LoDTensor>();
   std::vector<int64_t> ids_vector;
   if (platform::is_cpu_place(id_tensor.place())) {
     auto* id_data = id_tensor.data<int64_t>();
-    for (size_t i = 0; i < id_tensor.numel(); ++i) {
+    for (int64_t i = 0; i < id_tensor.numel(); ++i) {
       ids_vector.push_back(id_data[i]);
     }
   } else {
@@ -202,7 +209,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
     PADDLE_THROW("paddle is not compiled with CUDA!");
 #else
     auto cpu_place = platform::CPUPlace();
-    framework::Tensor cpu_tensor;
+    framework::LoDTensor cpu_tensor;
     auto* cpu_tensor_data =
         cpu_tensor.mutable_data<int64_t>(id_tensor.dims(), cpu_place);
     auto stream =
@@ -246,8 +253,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
   MergeMultipleVarsIntoOneBySection(id_name, ids_vector, out_name,
                                     out_var_names, height_sections, splited_ids,
                                     context, &local_scope, &actual_ctx);
-
-  context.scope().DeleteScope(&local_scope);
+  scope.DeleteScope(&local_scope);
 }
 
 };  // namespace distributed
diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.h b/paddle/fluid/operators/distributed/parameter_prefetch.h
index 53b0fbfb51f60fa86351cca34fd1665c7802591b..2f850a0332256d458e79ed9da361c86eb8a2f780 100644
--- a/paddle/fluid/operators/distributed/parameter_prefetch.h
+++ b/paddle/fluid/operators/distributed/parameter_prefetch.h
@@ -27,7 +27,56 @@ void prefetch(const std::string& id_name, const std::string& out_name,
               const std::vector<std::string>& table_names,
               const std::vector<std::string>& epmap,
               const std::vector<int>& height_sections,
-              const framework::ExecutionContext& context);
+              const framework::ExecutionContext& context,
+              const framework::Scope& scope);
+
+template <typename T>
+void prefetch_with_reconstruct(const std::string& id_name,
+                               const std::string& out_name,
+                               const std::vector<std::string>& table_names,
+                               const std::vector<std::string>& epmap,
+                               const std::vector<int>& height_sections,
+                               const framework::ExecutionContext& context,
+                               const framework::Scope& scope,
+                               framework::LoDTensor* original) {
+  prefetch(id_name, out_name, table_names, epmap, height_sections, context,
+           scope);
+  auto& out = scope.FindVar(out_name)->Get<framework::LoDTensor>();
+  auto& ids = scope.FindVar(id_name)->Get<framework::LoDTensor>();
+  auto* original_value = original->data<T>();
+  auto* out_value = out.data<T>();
+  size_t original_width = original->numel() / original->dims()[0];
+
+  bool is_on_cpu_place = true;
+  if (!platform::is_cpu_place(ids.place())) {
+    is_on_cpu_place = false;
+  }
+  if (is_on_cpu_place) {
+    for (int64_t i = 0; i < ids.numel(); i++) {
+      const T* out_rows = out_value + original_width * i;
+      T* original_row =
+          original_value + original_width * ids.data<int64_t>()[i];
+      std::memcpy(original_row, out_rows, original_width * sizeof(T));
+    }
+  } else {
+#ifndef PADDLE_WITH_CUDA
+    PADDLE_THROW("paddle is not compiled with CUDA!");
+#else
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    auto& actual_ctx = *pool.Get(context.GetPlace());
+    for (int64_t i = 0; i < ids.numel(); i++) {
+      const T* out_rows = out_value + original_width * i;
+      T* original_row =
+          original_value + original_width * ids.data<int64_t>()[i];
+      auto stream =
+          static_cast<platform::CUDADeviceContext*>(&actual_ctx)->stream();
+      memory::Copy(boost::get<platform::CUDAPlace>(ids.place()), original_row,
+                   platform::CPUPlace(), out_rows, original_width * sizeof(T),
+                   stream);
+    }
+#endif
+  }
+}
 
 };  // namespace distributed
 };  // namespace operators
diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.cc b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
index a807117115763486a58052a6240cdedba6af9ac8..6ca6f0bc04aa696852ed7338dcb4b88a49b2fc81 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.cc
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.cc
@@ -67,6 +67,11 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("PreOut"),
                    "Output(PreOut) should not be null.");
+    auto with_prefetch = ctx->Attrs().Get<bool>("remote_prefetch");
+    if (with_prefetch) {
+      PADDLE_ENFORCE(ctx->HasOutput("W_Out"),
+                     "Output(W_Out) should not be null.");
+    }
     const int64_t batch_size = ctx->GetInputDim("X")[0];
     std::vector<int64_t> output_shape({batch_size, 1});
     ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
@@ -95,7 +100,7 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Label",
              "(LoDTensor, required), The labels of training data. It's a"
              "tensor with shape [N, 1].");
-    AddInput("PTable",
+    AddInput("PathTable",
              "(LoDTensor, optional), The Path Table from root to current word"
              "it should have shape like [N, L], L is the length of the Path")
         .AsDispensable();
@@ -119,8 +124,30 @@ class HierarchicalSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
               "[batch_size, code_length], where code_length represents the "
               "maximum path length from root to leaf nodes.")
         .AsIntermediate();
+    AddOutput(
+        "W_Out",
+        "(LoDTensor, optinal) using input 'W' as Output to make it mutable"
+        "When we are using prefetch")
+        .AsIntermediate();
     AddAttr<AttrType>("num_classes", "(int, optional), The number of classes")
         .SetDefault(2);
+    // for parameter prefetch
+    AddAttr<bool>("remote_prefetch", "").SetDefault(false);
+    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
+    AddAttr<std::vector<int>>("height_sections",
+                              "Height for each output SelectedRows.")
+        .SetDefault(std::vector<int>({}));
+    AddAttr<std::vector<std::string>>(
+        "epmap",
+        "(string vector, default 127.0.0.1:6164)"
+        "Server endpoints in the order of input variables for mapping")
+        .SetDefault({});
+    AddAttr<std::vector<std::string>>(
+        "table_names",
+        "(string vector, the splited table names that will be fetched from "
+        "parameter server)"
+        "in the order of input variables for mapping")
+        .SetDefault({});
     AddComment(R"DOC(
 The hierarchical sigmoid operator organize the classes into a binary tree.
 At each node, a sigmoid function is used to calculate the probability of
@@ -189,23 +216,17 @@ class HierarchicalSigmoidGradOpGradVarTypeInference
                << " is set to SelectedRows";
       block->Var(w_grad_var_name)
           ->SetType(framework::proto::VarType::SELECTED_ROWS);
-      if (hasBias) {
-        VLOG(30) << "hierarchical_sigmoid_grad op "
-                 << framework::GradVarName("Bias") << " is set to SelectedRows";
-        block->Var(bias_grad_var_name)
-            ->SetType(framework::proto::VarType::SELECTED_ROWS);
-      }
     } else {
       VLOG(30) << "hierarchical_sigmoid_grad op " << framework::GradVarName("W")
                << " is set to LoDTensor";
       block->Var(w_grad_var_name)
           ->SetType(framework::proto::VarType::LOD_TENSOR);
-      if (hasBias) {
-        VLOG(30) << "hierarchical_sigmoid_grad op "
-                 << framework::GradVarName("Bias") << " is set to LoDTensor";
-        block->Var(bias_grad_var_name)
-            ->SetType(framework::proto::VarType::LOD_TENSOR);
-      }
+    }
+    if (hasBias) {
+      VLOG(30) << "hierarchical_sigmoid_grad op "
+               << framework::GradVarName("Bias") << " is set to LoDTensor";
+      block->Var(bias_grad_var_name)
+          ->SetType(framework::proto::VarType::LOD_TENSOR);
     }
     block->Var(w_grad_var_name)->SetDataType(block->Var("W")->GetDataType());
   }
diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h
index d212e6f8437e69e71c010b6af27a33ff5e39e1e1..1a7ca963010112bbcab69f1ceeb9cb8d19ca9b9e 100644
--- a/paddle/fluid/operators/hierarchical_sigmoid_op.h
+++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h
@@ -14,7 +14,9 @@ limitations under the License. */
 
 #pragma once
 #include <iostream>
+#include <iterator>
 #include <set>
+#include <string>
 #include <vector>
 #include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -24,6 +26,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/matrix_bit_code.h"
 #include "paddle/fluid/platform/transform.h"
 
+#ifdef PADDLE_WITH_DISTRIBUTE
+#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
+#endif
+
 namespace paddle {
 namespace operators {
 
@@ -34,8 +40,9 @@ using platform::Transform;
 
 static std::vector<int64_t> PathToRows(const framework::LoDTensor& path) {
   std::set<int64_t> rows;
+  const int64_t* paths = path.data<int64_t>();
   for (int64_t i = 0; i < path.numel(); ++i) {
-    int64_t row = path.data<int64_t>()[i];
+    int64_t row = paths[i];
     if (row < 0) {
       continue;
     }
@@ -49,13 +56,54 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
     auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
-    auto* path = ctx.Input<framework::LoDTensor>("PTable");
+    auto* path = ctx.Input<framework::LoDTensor>("PathTable");
     auto* code = ctx.Input<framework::LoDTensor>("PathCode");
     auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label"));
     auto* bias = ctx.Input<framework::LoDTensor>("Bias");
     auto* out = ctx.Output<framework::LoDTensor>("Out");
     auto* pre_out = ctx.Output<framework::LoDTensor>("PreOut");
     size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
+    // for remote prefetch
+
+    auto epmap = ctx.Attr<std::vector<std::string>>("epmap");
+    if (!epmap.empty()) {
+      // if epmap is not empty, then the parameter will be fetched from remote
+      // parameter
+      // server
+      auto height_sections = ctx.Attr<std::vector<int>>("height_sections");
+      auto table_names = ctx.Attr<std::vector<std::string>>("table_names");
+      std::vector<int64_t> real_rows = PathToRows(*path);
+      framework::Scope& local_scope = ctx.scope().NewScope();
+      auto* ids = local_scope.Var("Ids@Prefetch");
+      auto* x_tensor = ids->GetMutable<framework::LoDTensor>();
+
+      x_tensor->mutable_data<int64_t>(
+          framework::make_ddim({static_cast<int64_t>(real_rows.size()), 1}),
+          ctx.GetPlace());
+      // copy.
+
+      std::memcpy(x_tensor->data<int64_t>(), real_rows.data(),
+                  real_rows.size() * sizeof(int64_t));
+
+      framework::DDim w_dims = ctx.Input<Tensor>("W")->dims();
+      w_dims[0] = x_tensor->dims()[0];
+      auto* w_tensor =
+          local_scope.Var("W@Prefetch")->GetMutable<framework::LoDTensor>();
+      w_tensor->Resize(w_dims);
+
+#ifdef PADDLE_WITH_DISTRIBUTE
+      // w_Out is set to used by prefetch, never change it in other cases
+      auto* w_out = ctx.Output<framework::LoDTensor>("W_Out");
+      operators::distributed::prefetch_with_reconstruct<T>(
+          "Ids@Prefetch", "W@Prefetch", table_names, epmap, height_sections,
+          ctx, local_scope, w_out);
+#else
+      PADDLE_THROW(
+          "paddle is not compiled with distribute support, can not do "
+          "parameter prefetch!");
+#endif
+    }
+
     bool is_custom = false;
     if (path) {
       is_custom = true;
@@ -116,9 +164,8 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X"));
     auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W"));
-    auto* path = ctx.Input<framework::LoDTensor>("PTable");
+    auto* path = ctx.Input<framework::LoDTensor>("PathTable");
     auto* code = ctx.Input<framework::LoDTensor>("PathCode");
-    auto* bias = ctx.Input<framework::LoDTensor>("Bias");
     auto* in_grad =
         ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
     bool is_sparse = ctx.Attr<bool>("is_sparse");
@@ -173,15 +220,14 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
     }
     // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
     // be consistent with the clipping in forward.
-
+    auto* bias_grad =
+        ctx.Output<framework::LoDTensor>(framework::GradVarName("Bias"));
+    if (bias_grad) {
+      bias_grad->mutable_data<T>(ctx.GetPlace());
+      zero(dev_ctx, bias_grad, static_cast<T>(0.0));
+      bit_code->AddGrad(pre_out_grad, bias_grad);
+    }
     if (!is_sparse) {
-      auto* bias_grad =
-          ctx.Output<framework::LoDTensor>(framework::GradVarName("Bias"));
-      if (bias_grad) {
-        bias_grad->mutable_data<T>(ctx.GetPlace());
-        zero(dev_ctx, bias_grad, static_cast<T>(0.0));
-        bit_code->AddGrad(pre_out_grad, bias_grad);
-      }
       auto* w_grad =
           ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
       w_grad->mutable_data<T>(ctx.GetPlace());
@@ -200,21 +246,6 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
 
       w_grad_value->mutable_data<T>(temp_dim, ctx.GetPlace());
       zero(dev_ctx, w_grad_value, static_cast<T>(0.0));
-      auto* bias_grad =
-          ctx.Output<framework::SelectedRows>(framework::GradVarName("Bias"));
-      if (bias_grad) {
-        bias_grad->set_rows(real_rows);
-        // build ids -> rows index map
-        bias_grad->SyncIndex();
-        bias_grad->set_height(bias->dims()[0]);
-        auto* bias_grad_value = bias_grad->mutable_value();
-        std::vector<int64_t> dims = {static_cast<int64_t>(real_rows.size()),
-                                     bias->dims()[1]};
-        bias_grad_value->mutable_data<T>(framework::make_ddim(dims),
-                                         ctx.GetPlace());
-        zero(dev_ctx, bias_grad_value, static_cast<T>(0.0));
-        bit_code->AddGrad(pre_out_grad, bias_grad);
-      }
       bit_code->MulGradWeight(pre_out_grad, w_grad, in);
     }
     bit_code->MulGradError(pre_out_grad, w, in_grad);
diff --git a/paddle/fluid/operators/lookup_table_op.cu b/paddle/fluid/operators/lookup_table_op.cu
index 6a0d6bad512fe7cc15e60ed25028bc3cbbbca2ab..fd15539f7b6727496988c9b13d0d2551659a420a 100644
--- a/paddle/fluid/operators/lookup_table_op.cu
+++ b/paddle/fluid/operators/lookup_table_op.cu
@@ -92,7 +92,8 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
 // server
 #ifdef PADDLE_WITH_DISTRIBUTE
       operators::distributed::prefetch(id_name, out_name, table_names, epmap,
-                                       height_sections, context);
+                                       height_sections, context,
+                                       context.scope());
 #else
       PADDLE_THROW(
           "paddle is not compiled with distribute support, can not do "
diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h
index 3a73a7637c6d7d3eff7443802a4a52be9149e0ef..a7d0fd4856edc74237151c64f286d468ad86e7ca 100644
--- a/paddle/fluid/operators/lookup_table_op.h
+++ b/paddle/fluid/operators/lookup_table_op.h
@@ -59,7 +59,8 @@ class LookupTableKernel : public framework::OpKernel<T> {
 // server
 #ifdef PADDLE_WITH_DISTRIBUTE
       operators::distributed::prefetch(id_name, out_name, table_names, epmap,
-                                       height_sections, context);
+                                       height_sections, context,
+                                       context.scope());
 #else
       PADDLE_THROW(
           "paddle is not compiled with distribute support, can not do "
diff --git a/paddle/fluid/operators/math/matrix_bit_code.cc b/paddle/fluid/operators/math/matrix_bit_code.cc
index d55e832cc2d9a4a5e2cb7fe5cf451a1205601951..d6f51c6e5c693becb14ff0bac0088bb9dc2b2f55 100644
--- a/paddle/fluid/operators/math/matrix_bit_code.cc
+++ b/paddle/fluid/operators/math/matrix_bit_code.cc
@@ -84,41 +84,6 @@ void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor &tmat,
   code_table_.apply_visitor(func);
 }
 
-template <typename T>
-struct MatrixBitCodeFunctorSelectedRowsAddGrad
-    : public boost::static_visitor<void> {
-  const framework::Tensor &tmat_;
-  framework::SelectedRows *vec_;
-
-  MatrixBitCodeFunctorSelectedRowsAddGrad(const framework::Tensor &tmat,
-                                          framework::SelectedRows *vec)
-      : tmat_(tmat), vec_(vec) {}
-
-  template <typename CodeTable>
-  void operator()(const CodeTable &code_table) {
-    size_t batch_size = tmat_.dims()[0];
-    size_t width = tmat_.dims()[1];
-    auto *vec_data = vec_->mutable_value()->template data<T>();
-    auto *tmat_data = tmat_.data<T>();
-    for (size_t i = 0; i < batch_size; ++i) {
-      auto code = code_table.get_code(i);
-      int code_length = code.get_length();
-      for (int j = 0; j < code_length; ++j) {
-        size_t index = code.calc_index(j);
-        int64_t row_index = vec_->GetIndexFromId(static_cast<int64_t>(index));
-        vec_data[row_index] += tmat_data[i * width + j];
-      }
-    }
-  }
-};
-
-template <typename T>
-void MatrixBitCodeFunctor<T>::AddGrad(const framework::Tensor &tmat,
-                                      framework::SelectedRows *vec) {
-  MatrixBitCodeFunctorSelectedRowsAddGrad<T> func(tmat, vec);
-  code_table_.apply_visitor(func);
-}
-
 template <typename T>
 struct MatrixBitCodeFunctorSum : public boost::static_visitor<void> {
   const framework::Tensor &tmat_;
diff --git a/paddle/fluid/operators/math/matrix_bit_code.h b/paddle/fluid/operators/math/matrix_bit_code.h
index 01e4889d34ad6e409f1b8a9c4bf783800187e863..c399cb5d44aaa50fab00fd170c021c8c70eee990 100644
--- a/paddle/fluid/operators/math/matrix_bit_code.h
+++ b/paddle/fluid/operators/math/matrix_bit_code.h
@@ -124,11 +124,12 @@ class SimpleCode {
 template <typename T>
 class CustomCode {
  public:
-  CustomCode(const framework::Tensor& ptable, const framework::Tensor& pcode,
-             const int64_t* ids, int index) {
-    seq_len_ = ptable.dims()[1];
-    ptable_data_ = ptable.data<T>() + seq_len_ * index;
-    pcode_data_ = pcode.data<T>() + seq_len_ * index;
+  CustomCode(const framework::Tensor& path_table,
+             const framework::Tensor& path_code, const int64_t* ids,
+             int index) {
+    seq_len_ = path_table.dims()[1];
+    path_table_data_ = path_table.data<T>() + seq_len_ * index;
+    path_code_data_ = path_code.data<T>() + seq_len_ * index;
   }
   /**
    * Here the id of root should be 1 rather than 0, thus the encoding of class c
@@ -139,25 +140,25 @@ class CustomCode {
    * Binary classification path is the suffixes of encoding, thus leave out the
    * left most bit in calc_bit.
    */
-  size_t calc_index(int bit) const { return ptable_data_[bit]; }
-  bool calc_bit(int bit) const { return pcode_data_[bit]; }
+  size_t calc_index(int bit) const { return path_table_data_[bit]; }
+  bool calc_bit(int bit) const { return path_code_data_[bit]; }
 
   // NOTE: this function is not thread-safe.
   int get_length() const {
     if (length_ < 0) {
       auto len = seq_len_;
-      length_ =
-          static_cast<int>(std::find_if(ptable_data_, ptable_data_ + len,
-                                        [](const T& val) { return val < 0; }) -
-                           ptable_data_);
+      length_ = static_cast<int>(
+          std::find_if(path_table_data_, path_table_data_ + len,
+                       [](const T& val) { return val < 0; }) -
+          path_table_data_);
     }
     return length_;
   }
 
  private:
   int64_t seq_len_;
-  const T* ptable_data_;
-  const T* pcode_data_;
+  const T* path_table_data_;
+  const T* path_code_data_;
   mutable int length_{-1};
 };
 
@@ -181,9 +182,9 @@ class SimpleCodeTable {
 template <typename T>
 class CustomCodeTable {
  public:
-  CustomCodeTable(const framework::Tensor& ptable,
-                  const framework::Tensor& pcode, const int64_t* ids)
-      : ptable_(ptable), pcode_(pcode), ids_(ids) {}
+  CustomCodeTable(const framework::Tensor& path_table,
+                  const framework::Tensor& path_code, const int64_t* ids)
+      : ptable_(path_table), pcode_(path_code), ids_(ids) {}
 
   CustomCode<T> get_code(int64_t code) const {
     return CustomCode<T>(ptable_, pcode_, ids_, code);
@@ -210,11 +211,11 @@ class MatrixBitCodeFunctor {
         ids_(ids),
         code_table_(SimpleCodeTable(num_classes, ids)) {}
 
-  MatrixBitCodeFunctor(const framework::Tensor& ptable,
-                       const framework::Tensor& pcode, const int64_t* ids)
-      : num_classes_(static_cast<size_t>(ptable.dims()[1])),
+  MatrixBitCodeFunctor(const framework::Tensor& path_table,
+                       const framework::Tensor& path_code, const int64_t* ids)
+      : num_classes_(static_cast<size_t>(path_table.dims()[1])),
         ids_(ids),
-        code_table_(CustomCodeTable<int64_t>(ptable, pcode, ids)) {}
+        code_table_(CustomCodeTable<int64_t>(path_table, path_code, ids)) {}
   /* For j < code_length
        tmat(i, j) += vec(0, index(i, j))
   */
@@ -225,11 +226,6 @@ class MatrixBitCodeFunctor {
   */
   void AddGrad(const framework::Tensor& tmat, framework::Tensor* vec);
 
-  /* For selected rows For j < code_length
-       vec(0, index(i, j)) += tmat(i, j)
-  */
-  void AddGrad(const framework::Tensor& tmat, framework::SelectedRows* vec);
-
   /* For j < code_length
     sum(i, 0) = \sum_j bit(i, j) * tmat(i, j)
   */
diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc
index 784e07b5bd7f3836f3515c789f998ba1bf30f6e8..256da34912560ddf1f7e430e8543efe00e5885bc 100644
--- a/paddle/fluid/operators/nce_op.cc
+++ b/paddle/fluid/operators/nce_op.cc
@@ -153,6 +153,24 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<bool>("is_sparse", "(boolean, default false) Sparse update.")
         .SetDefault(false);
 
+    // for parameter prefetch
+    AddAttr<bool>("remote_prefetch", "").SetDefault(false);
+    AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
+    AddAttr<std::vector<int>>("height_sections",
+                              "Height for each output SelectedRows.")
+        .SetDefault(std::vector<int>({}));
+    AddAttr<std::vector<std::string>>(
+        "epmap",
+        "(string vector, default 127.0.0.1:6164)"
+        "Server endpoints in the order of input variables for mapping")
+        .SetDefault({});
+    AddAttr<std::vector<std::string>>(
+        "table_names",
+        "(string vector, the splited table names that will be fetched from "
+        "parameter server)"
+        "in the order of input variables for mapping")
+        .SetDefault({});
+
     AddAttr<std::vector<int>>("custom_neg_classes",
                               "This attribute only be used in unitest. Classes "
                               "in this list wiil be used as negative classes "
@@ -222,24 +240,20 @@ class NCEOpGradVarTypeInference : public framework::VarTypeInference {
   void operator()(const framework::OpDesc &op_desc,
                   framework::BlockDesc *block) const override {
     auto weight_grad = op_desc.Output(framework::GradVarName("Weight")).front();
-    auto bias_grad = op_desc.Output(framework::GradVarName("Bias")).front();
 
     auto attr = op_desc.GetAttr("is_sparse");
     bool is_sparse = boost::get<bool>(attr);
     if (is_sparse) {
-      VLOG(3) << "nce_op_grad op " << weight_grad << " and " << bias_grad
+      VLOG(3) << "nce_op_grad op " << weight_grad << " and "
               << " is set to SelectedRows";
       block->Var(weight_grad)
           ->SetType(framework::proto::VarType::SELECTED_ROWS);
-      block->Var(bias_grad)->SetType(framework::proto::VarType::SELECTED_ROWS);
     } else {
-      VLOG(3) << "nce_op_grad op " << weight_grad << " and " << bias_grad
+      VLOG(3) << "nce_op_grad op " << weight_grad << " and "
               << " is set to LoDTensor";
       block->Var(weight_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
-      block->Var(bias_grad)->SetType(framework::proto::VarType::LOD_TENSOR);
     }
     block->Var(weight_grad)->SetDataType(block->Var("Input")->GetDataType());
-    block->Var(bias_grad)->SetDataType(block->Var("Input")->GetDataType());
   }
 };
 
diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h
index f2ca6ec247fd1ea09b707c2eaaad0548c8aa5757..2c97eef096eb3d23273e362e658cb1b5fc808609 100644
--- a/paddle/fluid/operators/nce_op.h
+++ b/paddle/fluid/operators/nce_op.h
@@ -15,8 +15,10 @@ limitations under the License. */
 #pragma once
 
 #include <math.h>
+#include <iterator>
 #include <random>
 #include <set>
+#include <string>
 #include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -24,6 +26,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/sampler.h"
 #include "unsupported/Eigen/CXX11/Tensor"
 
+#ifdef PADDLE_WITH_DISTRIBUTE
+#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
+#endif
+
 namespace paddle {
 namespace operators {
 
@@ -43,7 +49,6 @@ void PrepareSamples(const framework::ExecutionContext &context,
   auto label = context.Input<Tensor>("Label");
   const int64_t *label_data = label->data<int64_t>();
   auto label_dims = label->dims();
-  //  int num_total_classes = context.Attr<int>("num_total_classes");
   // for unitest
   std::vector<int> custom_neg_classes =
       context.Attr<std::vector<int>>("custom_neg_classes");
@@ -144,15 +149,82 @@ class NCEKernel : public framework::OpKernel<T> {
     }
     // forward mul
     auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
-    auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
-    for (int64_t i = 0; i < sample_labels->numel(); ++i) {
-      Eigen::Tensor<T, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
-          (input_mat.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
-           weight_mat.chip(sample_labels_data[i], 0))
-              .sum();
-      sample_out_data[i] += result(0);
-      sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
+
+    // for remote prefetch
+    auto epmap = context.Attr<std::vector<std::string>>("epmap");
+
+    if (!epmap.empty()) {
+      // if epmap is not empty, then the parameter will be fetched from remote
+      // parameter
+      // server
+
+      std::vector<int64_t> labels;
+      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
+        labels.push_back(sample_labels_data[i]);
+      }
+      std::set<T> st(labels.begin(), labels.end());
+      labels.assign(st.begin(), st.end());
+
+      framework::Scope &local_scope = context.scope().NewScope();
+
+      auto height_sections = context.Attr<std::vector<int>>("height_sections");
+      auto table_names = context.Attr<std::vector<std::string>>("table_names");
+
+      auto *ids = local_scope.Var("Ids@Prefetch");
+      auto *x_tensor = ids->GetMutable<framework::LoDTensor>();
+      x_tensor->mutable_data<int64_t>(
+          framework::make_ddim({static_cast<int64_t>(labels.size()), 1}),
+          context.GetPlace());
+      // copy.
+      std::memcpy(x_tensor->data<int64_t>(), labels.data(),
+                  labels.size() * sizeof(int64_t));
+
+      std::vector<int> w_dims = paddle::framework::vectorize2int(
+          context.Input<Tensor>("Weight")->dims());
+      w_dims[0] = static_cast<int>(labels.size());
+
+      auto *w_tensor = local_scope.Var("Weight@Prefetch")
+                           ->GetMutable<framework::LoDTensor>();
+      w_tensor->Resize(framework::make_ddim(w_dims));
+
+#ifdef PADDLE_WITH_DISTRIBUTE
+      operators::distributed::prefetch("Ids@Prefetch", "Weight@Prefetch",
+                                       table_names, epmap, height_sections,
+                                       context, local_scope);
+#else
+      PADDLE_THROW(
+          "paddle is not compiled with distribute support, can not do "
+          "parameter prefetch!");
+#endif
+
+      auto weight_mat = EigenMatrix<T>::From(
+          (local_scope.Var("Weight@Prefetch")->Get<framework::LoDTensor>()));
+      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
+        std::vector<int64_t>::iterator it =
+            std::find(labels.begin(), labels.end(), sample_labels_data[i]);
+        int idx = std::distance(labels.begin(), it);
+
+        Eigen::Tensor<T, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
+            (input_mat.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
+             weight_mat.chip(idx, 0))
+                .sum();
+        sample_out_data[i] += result(0);
+        sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
+      }
+      context.scope().DeleteScope(&local_scope);
+    } else {
+      auto weight_mat =
+          EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
+      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
+        Eigen::Tensor<T, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
+            (input_mat.chip(static_cast<int>(i / sample_labels->dims()[1]), 0) *
+             weight_mat.chip(sample_labels_data[i], 0))
+                .sum();
+        sample_out_data[i] += result(0);
+        sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
+      }
     }
+
     // forward cost
     for (int64_t i = 0; i < sample_labels->dims()[0]; ++i) {
       out_data[i] = 0;
@@ -240,18 +312,19 @@ class NCEGradKernel : public framework::OpKernel<T> {
       sample_grad_data[i] *= d_out_data[sample_idx];
     }
 
+    // get d_bias
+    auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
+    if (d_bias != nullptr) {
+      T *d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
+      std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
+      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
+        d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
+      }
+    }
+
     bool is_sparse = context.Attr<bool>("is_sparse");
 
     if (!is_sparse) {
-      // get d_bias
-      auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
-      if (d_bias != nullptr) {
-        T *d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
-        std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
-        for (int64_t i = 0; i < sample_labels->numel(); ++i) {
-          d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
-        }
-      }
       // get d_w
       auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
       if (d_w != nullptr) {
@@ -273,34 +346,6 @@ class NCEGradKernel : public framework::OpKernel<T> {
       std::set<T> st(labels.begin(), labels.end());
       labels.assign(st.begin(), st.end());
 
-      auto *bias_var = context.InputVar("Bias");
-      DDim bias_dim;
-      if (bias_var->IsType<LoDTensor>()) {
-        bias_dim = context.Input<LoDTensor>("Bias")->dims();
-      } else if (bias_var->IsType<SelectedRows>()) {
-        auto *table_t = context.Input<SelectedRows>("Bias");
-        bias_dim = table_t->value().dims();
-      } else {
-        PADDLE_THROW(
-            "The parameter Bias of a NCE_OP "
-            "must be either LoDTensor or SelectedRows");
-      }
-
-      auto d_bias =
-          context.Output<SelectedRows>(framework::GradVarName("Bias"));
-      d_bias->set_rows(labels);
-      d_bias->set_height(bias_dim[0]);
-
-      d_bias->mutable_value()->Resize(
-          {static_cast<int64_t>(labels.size()), bias_dim[1]});
-      T *d_bias_data =
-          d_bias->mutable_value()->mutable_data<T>(context.GetPlace());
-      std::fill(d_bias_data, d_bias_data + labels.size(), 0.0);
-      for (int64_t i = 0; i < sample_labels->numel(); ++i) {
-        d_bias_data[d_bias->Index(sample_labels_data[i])] +=
-            sample_grad_data[i];
-      }
-
       auto *table_var = context.InputVar("Weight");
       DDim table_dim;
       if (table_var->IsType<LoDTensor>()) {
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 9572fcb385823eab16d5c44fd56c680e577c8f04..615a35ba916f813399dc21a87646884b3d01081e 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -26,7 +26,7 @@ from ..initializer import Normal, Constant
 from ..framework import Variable, OpProtoHolder
 from ..param_attr import ParamAttr
 from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
-from .tensor import concat
+from .tensor import concat, assign
 from . import utils
 from .. import unique_name
 from functools import reduce
@@ -340,9 +340,7 @@ def embedding(input,
     """
 
     helper = LayerHelper('embedding', **locals())
-    remote_prefetch = False
-    if os.environ.get('PADDLE_ENABLE_REMOTE_PREFETCH'):
-        remote_prefetch = True
+    remote_prefetch = is_sparse and (not is_distributed)
     if remote_prefetch:
         assert is_sparse is True and is_distributed is False
     w = helper.create_parameter(
@@ -5032,12 +5030,18 @@ def nce(input,
     else:
         num_neg_samples = int(num_neg_samples)
 
+    remote_prefetch = is_sparse
+    print(
+        "With sparse mode, if your models has only small parameter prefetch may cause speed down"
+    )
+
     attrs = {
         'num_total_classes': int(num_total_classes),
         'num_neg_samples': num_neg_samples,
         'seed': seed,
         'sampler': sampler,
-        'is_sparse': is_sparse
+        'is_sparse': is_sparse,
+        'remote_prefetch': remote_prefetch
     }
 
     helper.append_op(
@@ -5147,7 +5151,10 @@ def hsigmoid(input,
         pass
 
     weights = None
-
+    remote_prefetch = is_sparse
+    print(
+        "With sparse mode, if your models has only small parameter prefetch may cause speed down"
+    )
     if not is_custom:
         weights = helper.create_parameter(
             attr=helper.param_attr,
@@ -5163,7 +5170,7 @@ def hsigmoid(input,
     inputs = {
         "X": input,
         "W": weights,
-        "PTable": path_table,
+        "PathTable": path_table,
         "PathCode": path_code,
         "Label": label
     }
@@ -5186,9 +5193,13 @@ def hsigmoid(input,
         type="hierarchical_sigmoid",
         inputs=inputs,
         outputs={"Out": out,
-                 "PreOut": pre_out},
-        attrs={"num_classes": num_classes,
-               "is_sparse": is_sparse})
+                 "PreOut": pre_out,
+                 "W_Out": weights},
+        attrs={
+            "num_classes": num_classes,
+            "is_sparse": is_sparse,
+            "remote_prefetch": remote_prefetch
+        })
     return out
 
 
@@ -7684,7 +7695,7 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
 
     Examples:
 
-        .. code-block:: python
+    .. code-block:: python
 
             x = fluid.layers.data(name="x", shape=[2,3,16,16], dtype="float32")
             y = fluid.layers.brelu(x, t_min=1.0, t_max=20.0)
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 344130499506fbbecdda6551dc9abe3fca22d153..ec8b19c7ba07a9e57a32277ff3fc34b0ea25a819 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -21,6 +21,8 @@ if(NOT WITH_DISTRIBUTE)
     LIST(REMOVE_ITEM TEST_OPS test_dist_simnet_bow)
     LIST(REMOVE_ITEM TEST_OPS test_dist_mnist_batch_merge)
     LIST(REMOVE_ITEM TEST_OPS test_dist_text_classification)
+    LIST(REMOVE_ITEM TEST_OPS test_nce_remote_table_op)
+    LIST(REMOVE_ITEM TEST_OPS test_hsigmoid_remote_table_op)
 endif(NOT WITH_DISTRIBUTE)
 
 if (NOT ${WITH_GPU})
@@ -32,7 +34,6 @@ endif()
 list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
 list(REMOVE_ITEM TEST_OPS test_modified_huber_loss_op) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5184
 list(REMOVE_ITEM TEST_OPS test_lstm_unit_op) # # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5185
-list(REMOVE_ITEM TEST_OPS test_nce) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/7778
 list(REMOVE_ITEM TEST_OPS test_recurrent_op) # FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/6152
 list(REMOVE_ITEM TEST_OPS test_cond_op) # FIXME(qijun): https://github.com/PaddlePaddle/Paddle/issues/5101#issuecomment-339814957
 
diff --git a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
index d9ad4e2e2c7b8d0a99d917495fbc8efc6cbd188d..3d1ce6b27c935ddca0f2f5fb377e69b571e3714c 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_transpiler.py
@@ -14,14 +14,15 @@
 
 from __future__ import print_function
 
+import traceback
 import math
+import collections
 
+import six
 import unittest
+import numpy as np
+
 import paddle.fluid as fluid
-from paddle.fluid.transpiler.distribute_transpiler import delete_ops
-import traceback
-import collections
-import six
 
 
 class TranspilerTest(unittest.TestCase):
@@ -520,7 +521,7 @@ class TestLocalLookupTable(TestDistLookupTableBase):
             'split_selected_rows', 'send', 'sequence_pool_grad',
             'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
             'sum', 'split_selected_rows', 'send', 'send_barrier', 'recv',
-            'recv', 'recv', 'recv', 'fetch_barrier', 'concat', 'concat'
+            'recv', 'fetch_barrier'
         ]
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
 
@@ -560,7 +561,7 @@ class TestDistLookupTable(TestDistLookupTableBase):
             'lookup_table_grad', 'split_selected_rows', 'send',
             'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
             'lookup_table_grad', 'sum', 'split_ids', 'send', 'send_barrier',
-            'recv', 'recv', 'recv', 'fetch_barrier', 'concat'
+            'recv', 'recv', 'fetch_barrier'
         ]
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
         startup_ops = [
@@ -607,8 +608,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase):
             'send', 'concat_grad', 'sequence_pool_grad', 'lookup_table_grad',
             'split_selected_rows', 'send', 'sequence_pool_grad',
             'lookup_table_grad', 'sequence_pool_grad', 'lookup_table_grad',
-            'sum', 'split_selected_rows', 'send', 'recv', 'recv', 'recv',
-            'recv', 'concat', 'concat'
+            'sum', 'split_selected_rows', 'send', 'recv', 'recv'
         ]
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
 
@@ -648,8 +648,7 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
             'mul_grad', 'send', 'concat_grad', 'sequence_pool_grad',
             'lookup_table_grad', 'split_selected_rows', 'send',
             'sequence_pool_grad', 'lookup_table_grad', 'sequence_pool_grad',
-            'lookup_table_grad', 'sum', 'split_ids', 'send', 'recv', 'recv',
-            'recv', 'concat'
+            'lookup_table_grad', 'sum', 'split_ids', 'send', 'recv', 'recv'
         ]
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
         startup_ops = [
@@ -824,5 +823,142 @@ class TestRemoteLookupTable(TestDistLookupTableBase):
         self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
 
 
+# test for remote prefetch
+class TestRemoteNce(TestDistLookupTableBase):
+    def network_with_table(self, is_sparse, is_distributed):
+
+        num_total_classes = 20
+        sampler = "uniform"
+        nid_freq_arr = np.random.dirichlet(np.ones(20) * 1000).astype('float32')
+
+        input = fluid.layers.data(name="input", shape=[10], dtype="float32")
+        label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+
+        w_param = fluid.default_main_program().global_block().create_parameter(
+            shape=[num_total_classes, 10],
+            dtype='float32',
+            name='nce_w',
+            initializer=fluid.initializer.ConstantInitializer())
+        b_param = fluid.default_main_program().global_block().create_parameter(
+            shape=[num_total_classes, 1],
+            dtype='float32',
+            name='nce_b',
+            initializer=fluid.initializer.ConstantInitializer())
+
+        cost = fluid.layers.nce(input=input,
+                                label=label,
+                                num_total_classes=num_total_classes,
+                                sampler=sampler,
+                                custom_dist=nid_freq_arr.tolist(),
+                                sample_weight=None,
+                                param_attr='nce_w',
+                                bias_attr='nce_b',
+                                seed=1,
+                                num_neg_samples=5,
+                                is_sparse=is_sparse)
+        avg_cost = fluid.layers.mean(cost)
+        # optimizer
+        optimizer = fluid.optimizer.Adam(learning_rate=0.003)
+        optimizer.minimize(avg_cost)
+
+    def net_conf(self):
+        import os
+        os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
+        self.network_with_table(is_sparse=True, is_distributed=False)
+
+    def transpiler_test_impl(self):
+        trainer, _ = self.get_trainer()
+
+        out_vars = ["nce_w"]
+        in_vars = ["nce_b"]
+
+        recv_var_names = []
+
+        for op in trainer.blocks[0].ops:
+            if op.type == "recv":
+                for var in op.output("Out"):
+                    recv_var_names.append(var)
+
+        for out_var in out_vars:
+            self.assertFalse(out_var in recv_var_names)
+        for in_var in in_vars:
+            self.assertTrue(in_var in recv_var_names)
+
+
+# test for remote prefetch
+class TestRemoteHsigmoid(TestDistLookupTableBase):
+    def network_with_table(self, is_sparse, is_distributed):
+
+        num_total_classes = 3
+
+        input = fluid.layers.data(name="input", shape=[1], dtype="float32")
+        label = fluid.layers.data(name="label", shape=[1], dtype="int64")
+        path_table = fluid.layers.data(
+            name='path_table', shape=[3], dtype='int64')
+        path_code = fluid.layers.data(
+            name='path_code', shape=[3], dtype='int64')
+        w_param = fluid.default_main_program().global_block().create_parameter(
+            shape=[num_total_classes, 10],
+            dtype='float32',
+            name='hs_w',
+            initializer=fluid.initializer.ConstantInitializer())
+        b_param = fluid.default_main_program().global_block().create_parameter(
+            shape=[3, 1],
+            dtype='float32',
+            name='hs_b',
+            initializer=fluid.initializer.ConstantInitializer())
+
+        emb = fluid.layers.embedding(
+            input=input,
+            is_sparse=is_sparse,
+            size=[3, 3],
+            param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
+                scale=1 / math.sqrt(num_total_classes))))
+
+        cost = fluid.layers.hsigmoid(
+            input=emb,
+            label=label,
+            num_classes=num_total_classes,
+            path_table=path_table,
+            path_code=path_code,
+            is_custom=True,
+            is_sparse=is_sparse)
+        avg_cost = fluid.layers.mean(cost)
+        # optimizer
+        optimizer = fluid.optimizer.SGD(learning_rate=0.003)
+        optimizer.minimize(avg_cost)
+
+    def net_conf(self):
+        import os
+        os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
+        self.network_with_table(is_sparse=True, is_distributed=False)
+
+    def transpiler_test_impl(self):
+        trainer, _ = self.get_trainer()
+        params_to_check = list()
+        for op in trainer.blocks[0].ops:
+            if op.type == "hierarchical_sigmoid":
+                params_to_check = [op.input("W")[0], op.input("Bias")[0]]
+                for name in ["epmap", "table_names", "epmap"]:
+                    assert op.has_attr(name)
+                    if name == "epmap":
+                        assert op.attr(name)[0] == u'127.0.0.1:6174'
+                    elif name == "table_names":
+                        assert op.attr(name)[0] == u'hierarchical_sigmoid_0.w_0'
+                    else:
+                        assert op.attr(name) == 3
+            elif op.type == "lookup_table":
+                params_to_check.append(op.input("W")[0])
+            else:
+                pass
+        op_count = 0
+        for op in trainer.blocks[0].ops:
+            if op.type == "recv":
+                assert len(op.output("Out")) == 1
+                assert op.output("Out")[0] == u'hierarchical_sigmoid_0.b_0'
+                op_count += 1
+        assert op_count == 1
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
index 2a6c93f75fad53440a2db64e4f34c9a5c22c654e..8ed5074dc2626ff58fc65d8af1340e260c029572 100644
--- a/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
+++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_op.py
@@ -185,7 +185,7 @@ class TestHSigmoidOpSparse(OpTest):
         self.inputs = {
             'X': x,
             'W': w,
-            'PTable': path_table,
+            'PathTable': path_table,
             'PathCode': path_code,
             'Label': label,
             'Bias': bias
@@ -287,7 +287,7 @@ class TestHSigmoidOpWithCostumTree(OpTest):
         self.inputs = {
             'X': x,
             'W': w,
-            'PTable': path_table,
+            'PathTable': path_table,
             'PathCode': path_code,
             'Label': label,
             'Bias': bias
@@ -324,7 +324,7 @@ class TestHSigmoidOpWithCostumTreeWithoutBias(OpTest):
         self.inputs = {
             'X': x,
             'W': w,
-            'PTable': path_table,
+            'PathTable': path_table,
             'PathCode': path_code,
             'Label': label,
         }
diff --git a/python/paddle/fluid/tests/unittests/test_hsigmoid_remote_table_op.py b/python/paddle/fluid/tests/unittests/test_hsigmoid_remote_table_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..da343dd503a62e83f431dd0ffb02a7e70be7d0d5
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_hsigmoid_remote_table_op.py
@@ -0,0 +1,269 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import os
+import signal
+import time
+import unittest
+from multiprocessing import Process
+
+import numpy as np
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.fluid.op import Operator
+from paddle.fluid.framework import Program, program_guard
+
+
+def run_pserver(pserver_id, use_cuda, sync_mode):
+    scope = fluid.core.Scope()
+    program = Program()
+    with fluid.scope_guard(scope):
+        with program_guard(program, startup_program=Program()):
+            # create table parameter in scope
+            place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+            # create and initialize Param Variable
+            param = scope.var('table').get_tensor()
+
+            param_array = np.ones((5, 8)).astype("float32")
+            for i in range(len(param_array)):
+                param_array[i] *= param_array[i] * i + pserver_id * 10 + 1
+            param.set(param_array, place)
+
+            optimize_block = program._create_block(program.global_block().idx)
+            program.global_block().append_op(
+                type="listen_and_serv",
+                inputs={'X': []},
+                outputs={},
+                attrs={
+                    "optimize_blocks": [optimize_block],
+                    "endpoint": '127.0.0.1:0',
+                    "Fanin": 1,
+                    "sync_mode": True,
+                    "grad_to_block_id": []
+                })
+
+            exe = fluid.Executor(place)
+            exe.run(program)
+
+
+class TestListenAndServOp(unittest.TestCase):
+    def setUp(self):
+        self.ps_timeout = 5
+
+    def _start_pserver(self, pserver_id, use_cuda, sync_mode, pserver_func):
+        p = Process(target=pserver_func, args=(pserver_id, use_cuda, sync_mode))
+        p.daemon = True
+        p.start()
+        return p
+
+    def _wait_ps_ready(self, pid):
+        start_left_time = self.ps_timeout
+        sleep_time = 0.5
+        while True:
+            assert start_left_time >= 0, "wait ps ready failed"
+            time.sleep(sleep_time)
+            try:
+                # the listen_and_serv_op would touch a file which contains the listen port
+                # on the /tmp directory until it was ready to process all the RPC call.
+                os.stat("/tmp/paddle.%d.port" % pid)
+                return
+            except os.error:
+                start_left_time -= sleep_time
+
+    def _get_pserver_port(self, pid):
+        with open("/tmp/paddle.%d.port" % pid, 'r') as f:
+            port = int(f.read().strip())
+        return port
+
+    def _run_hsigmoid_op_one_pserver(self, place, port):
+        scope = fluid.core.Scope()
+        program = Program()
+        with fluid.scope_guard(scope):
+            with program_guard(program, startup_program=Program()):
+                x = scope.var('X').get_tensor()
+                x_array = np.random.random((4, 8)).astype("float32") * 2
+                x.set(x_array, place)
+                # create and initialize Param Variable
+                param = scope.var('W').get_tensor()
+                param_array = np.zeros((5, 8)).astype("float32") * 2
+                param.set(param_array, place)
+
+                path_table = scope.var('PathTable').get_tensor()
+                path_table_array = np.array(
+                    [(0, 2, -1, -1, -1), (0, 1, 2, -1, -1), (0, 1, 4, -1, -1),
+                     (0, 2, -1, -1, -1)]).astype(
+                         "int64"
+                     )  #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
+                path_table.set(path_table_array, place)
+
+                path_code = scope.var('PathCode').get_tensor()
+                path_code_array = np.array(
+                    [(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (1, 0, 0, -1, -1),
+                     (0, 1, -1, -1, -1)]).astype("int64")  #np.array to store 
+                path_code.set(path_code_array, place)
+
+                label = scope.var('Label').get_tensor()
+                label_array = np.array([0, 1, 4, 5])
+                label.set(label_array, place)
+
+                bias = scope.var('Bias').get_tensor()
+                bias_array = np.random.random((5, 1)).astype("float32")
+                bias.set(bias_array, place)
+
+                out = scope.var('Out').get_tensor()
+
+                pre_out = scope.var('PreOut').get_tensor
+
+                w_out = scope.var('W_Out').get_tensor()
+                w_out.set(param_array, place)
+
+                emaps = ['127.0.0.1:' + str(port)]
+                table_names = ['table']
+                height_sections = [2]
+
+                # create and run sgd operator
+                hsigmoid_op = Operator(
+                    "hierarchical_sigmoid",
+                    X='X',
+                    W='W',
+                    PathTable='PathTable',
+                    PathCode='PathCode',
+                    Label='Label',
+                    Bias='Bias',
+                    Out='Out',
+                    PreOut='PreOut',
+                    W_Out='W_Out',
+                    remote_prefetch=True,
+                    epmap=emaps,
+                    table_names=table_names,
+                    height_sections=height_sections)
+
+                hsigmoid_op.run(scope, place)
+
+                # get and compare result
+                result_array = np.array(w_out)
+                self.assertEqual(list(result_array.shape), [5, 8])
+                correct = None
+                for i in range(5):
+                    if i != 3:
+                        correct = np.full((1, 8), i + 1).astype("float32")
+                        self.assertTrue((result_array[i] == correct).all())
+                    else:
+                        correct = np.full((1, 8), 0).astype("float32")
+                        self.assertTrue((result_array[i] == correct).all())
+
+    def _run_hsigmoid_op_two_pserver(self, place, port0, port1):
+        scope = fluid.core.Scope()
+        program = Program()
+        with fluid.scope_guard(scope):
+            with program_guard(program, startup_program=Program()):
+                x = scope.var('X').get_tensor()
+                x_array = np.random.random((4, 8)).astype("float32") * 2
+                x.set(x_array, place)
+                # create and initialize Param Variable
+                param = scope.var('W').get_tensor()
+                param_array = np.zeros((5, 8)).astype("float32") * 2
+                param.set(param_array, place)
+
+                path_table = scope.var('PathTable').get_tensor()
+                path_table_array = np.array(
+                    [(0, 2, -1, -1, -1), (0, 1, 3, -1, -1), (0, 1, 4, -1, -1),
+                     (0, 2, -1, -1, -1)]).astype(
+                         "int64"
+                     )  #np.array to store 1,2,5,6s' non-leaf path(root -> leaf)
+                path_table.set(path_table_array, place)
+
+                path_code = scope.var('PathCode').get_tensor()
+                path_code_array = np.array(
+                    [(0, 0, -1, -1, -1), (1, 1, 1, -1, -1), (1, 0, 0, -1, -1),
+                     (0, 1, -1, -1, -1)]).astype("int64")  #np.array to store 
+                path_code.set(path_code_array, place)
+
+                label = scope.var('Label').get_tensor()
+                label_array = np.array([0, 1, 4, 5])
+                label.set(label_array, place)
+
+                bias = scope.var('Bias').get_tensor()
+                bias_array = np.random.random((5, 1)).astype("float32")
+                bias.set(bias_array, place)
+
+                out = scope.var('Out').get_tensor()
+
+                pre_out = scope.var('PreOut').get_tensor
+
+                w_out = scope.var('W_Out').get_tensor()
+                w_out.set(param_array, place)
+
+                emaps = ['127.0.0.1:' + str(port0), '127.0.0.1:' + str(port1)]
+                table_names = ['table', 'table']
+                height_sections = [2, 3]
+
+                # create and run sgd operator
+                hsigmoid_op = Operator(
+                    "hierarchical_sigmoid",
+                    X='X',
+                    W='W',
+                    PathTable='PathTable',
+                    PathCode='PathCode',
+                    Label='Label',
+                    Bias='Bias',
+                    Out='Out',
+                    PreOut='PreOut',
+                    W_Out='W_Out',
+                    remote_prefetch=True,
+                    epmap=emaps,
+                    table_names=table_names,
+                    height_sections=height_sections)
+                hsigmoid_op.run(scope, place)
+
+                # get and compare result
+                result_array = np.array(w_out)
+                self.assertEqual(list(result_array.shape), [5, 8])
+                correct = None
+                for i in range(5):
+                    if i < 2:
+                        correct = np.full((1, 8), i + 1).astype("float32")
+                        self.assertTrue((result_array[i] == correct).all())
+                    else:
+                        correct = np.full((1, 8), i + 9).astype("float32")
+                        self.assertTrue((result_array[i] == correct).all())
+
+    def test_hsigmoid_op_remote(self):
+        os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
+        # run pserver on CPU in sync mode
+        p0 = self._start_pserver(0, False, True, run_pserver)
+        self._wait_ps_ready(p0.pid)
+        port0 = self._get_pserver_port(p0.pid)
+
+        p1 = self._start_pserver(1, False, True, run_pserver)
+        self._wait_ps_ready(p1.pid)
+        port1 = self._get_pserver_port(p1.pid)
+
+        places = [core.CPUPlace()]
+
+        for place in places:
+            self._run_hsigmoid_op_one_pserver(place, port0)
+            self._run_hsigmoid_op_two_pserver(place, port0, port1)
+
+        # raise SIGTERM to pserver
+        os.kill(p0.pid, signal.SIGINT)
+        p0.join()
+        os.kill(p1.pid, signal.SIGINT)
+        p1.join()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py b/python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc6f40de86e302605a416c48790c74cbb431b2e3
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_nce_remote_table_op.py
@@ -0,0 +1,236 @@
+#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import os
+import signal
+import time
+import unittest
+from multiprocessing import Process
+
+import numpy as np
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.fluid.op import Operator
+from paddle.fluid.framework import Program, program_guard
+
+
+def nce(input, weight, bias, sample_weight, labels, num_classes,
+        num_sample_class):
+    samples = []
+    sample_labels = []
+    batch_size = input.shape[0]
+    num_true_class = labels.shape[1]
+    for i in range(batch_size):
+        w = 1 if sample_weight is None else sample_weight[i]
+        for label in labels[i]:
+            samples.append((i, label, True, w))
+            sample_labels.append(label)
+        for num in range(num_sample_class):
+            samples.append((i, num, False, w))
+            sample_labels.append(num)
+    # forward bias
+    sample_out = np.zeros(len(samples)).astype(np.float32)
+    if bias is not None:
+        for i in range(len(samples)):
+            sample_out[i] = bias[samples[i][1]]
+    # forward weight
+    for i in range(len(samples)):
+        sample_out[i] += np.dot(input[samples[i][0]], weight[samples[i][1]])
+
+    # forward activation
+    sample_out = 1.0 / (1.0 + np.exp(-sample_out))
+    # forward cost
+    out = np.zeros(batch_size).astype(np.float32)
+    b = 1.0 / num_classes * num_sample_class
+
+    for i in range(len(samples)):
+        o = sample_out[i]
+        cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b))
+        out[samples[i][0]] += cost * samples[i][3]
+    return (out[:, np.newaxis], np.array(sample_out).reshape(
+        batch_size, num_sample_class + num_true_class),
+            np.array(sample_labels).reshape(batch_size,
+                                            num_sample_class + num_true_class))
+
+
+def run_pserver(pserver_id, use_cuda, sync_mode):
+    scope = fluid.core.Scope()
+    program = Program()
+    with fluid.scope_guard(scope):
+        with program_guard(program, startup_program=Program()):
+            # create table parameter in scope
+            place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+            # create and initialize Param Variable
+            param = scope.var('table').get_tensor()
+
+            param_array = np.ones((5, 8)).astype("float32")
+            for i in range(len(param_array)):
+                param_array[i] *= param_array[i] * i + pserver_id * 10 + 1
+            param.set(param_array, place)
+
+            optimize_block = program._create_block(program.global_block().idx)
+            program.global_block().append_op(
+                type="listen_and_serv",
+                inputs={'X': []},
+                outputs={},
+                attrs={
+                    "optimize_blocks": [optimize_block],
+                    "endpoint": '127.0.0.1:0',
+                    "Fanin": 1,
+                    "sync_mode": True,
+                    "grad_to_block_id": []
+                })
+
+            exe = fluid.Executor(place)
+            exe.run(program)
+
+
+class TestListenAndServOp(unittest.TestCase):
+    def setUp(self):
+        self.ps_timeout = 5
+
+    def _start_pserver(self, pserver_id, use_cuda, sync_mode, pserver_func):
+        p = Process(target=pserver_func, args=(pserver_id, use_cuda, sync_mode))
+        p.daemon = True
+        p.start()
+        return p
+
+    def _wait_ps_ready(self, pid):
+        start_left_time = self.ps_timeout
+        sleep_time = 0.5
+        while True:
+            assert start_left_time >= 0, "wait ps ready failed"
+            time.sleep(sleep_time)
+            try:
+                # the listen_and_serv_op would touch a file which contains the listen port
+                # on the /tmp directory until it was ready to process all the RPC call.
+                os.stat("/tmp/paddle.%d.port" % pid)
+                return
+            except os.error:
+                start_left_time -= sleep_time
+
+    def _get_pserver_port(self, pid):
+        with open("/tmp/paddle.%d.port" % pid, 'r') as f:
+            port = int(f.read().strip())
+        return port
+
+    def _run_nce_op_two_pserver(self, place, port0, port1):
+        scope = fluid.core.Scope()
+        program = Program()
+        with fluid.scope_guard(scope):
+            with program_guard(program, startup_program=Program()):
+                x = scope.var('Input').get_tensor()
+                x_array = np.random.random((4, 8)).astype("float32")
+                x.set(x_array, place)
+                # create and initialize Param Variable
+                param = scope.var('Weight').get_tensor()
+                param_array = np.zeros((5, 8)).astype("float32")
+                param.set(param_array, place)
+
+                bias = scope.var('Bias').get_tensor()
+                bias_array = np.random.random((5, 1)).astype("float32")
+                bias.set(bias_array, place)
+
+                sample_w = scope.var('SampleWeight').get_tensor()
+                sample_weight = np.random.random((4, 1)).astype("float32")
+                sample_w.set(sample_weight, place)
+
+                label = scope.var('Label').get_tensor()
+                label_array = np.array([[0], [1], [4], [3]])
+                label.set(label_array, place)
+
+                cost = scope.var('Cost').get_tensor()
+                cost_w = np.zeros((4, 1)).astype("float32")
+                cost.set(cost_w, place)
+
+                sample_l = scope.var('SampleLogits').get_tensor()
+                sample_l_w = np.zeros((4, 3)).astype("float32")
+                sample_l.set(sample_l_w, place)
+
+                sample_la = scope.var('SampleLabels').get_tensor()
+                sample_la_w = np.zeros((4, 3)).astype("int")
+                sample_la.set(sample_la_w, place)
+
+                emaps = ['127.0.0.1:' + str(port0), '127.0.0.1:' + str(port1)]
+                table_names = ['table', 'table']
+                height_sections = [2, 3]
+
+                # create and run nce operator
+                nce_op = Operator(
+                    "nce",
+                    Input='Input',
+                    Weight='Weight',
+                    Label='Label',
+                    Bias='Bias',
+                    Cost='Cost',
+                    SampleLogits='SampleLogits',
+                    SampleLabels='SampleLabels',
+                    SampleWeight='SampleWeight',
+                    num_total_classes=5,
+                    num_neg_samples=2,
+                    custom_neg_classes=list(range(2)),
+                    sampler=0,
+                    seed=0,
+                    is_sparse=True,
+                    remote_prefetch=True,
+                    epmap=emaps,
+                    table_names=table_names,
+                    height_sections=height_sections)
+
+                nce_op.run(scope, place)
+
+                # get and compare result
+                o_cost = np.array(scope.var('Cost').get_tensor())
+                o_logits = np.array(scope.var('SampleLogits').get_tensor())
+                o_labels = np.array(scope.var('SampleLabels').get_tensor())
+
+                param_array = np.ones((5, 8)).astype("float32")
+                for i in range(2):
+                    param_array[i] *= param_array[i] * i + 0 * 10 + 1
+                for i in range(2, 5):
+                    param_array[i] *= param_array[i] * i + 1 * 10 + 1
+                out = nce(x_array, param_array, bias_array, sample_weight,
+                          label_array, 5, 2)
+
+                self.assertAlmostEqual(o_cost.all(), out[0].all(), delta=1e-6)
+                self.assertAlmostEqual(o_logits.all(), out[1].all(), delta=1e-6)
+                self.assertAlmostEqual(o_labels.all(), out[2].all(), delta=1e-6)
+
+    def test_nce_op_remote(self):
+        os.environ['PADDLE_ENABLE_REMOTE_PREFETCH'] = "1"
+        # run pserver on CPU in sync mode
+        p0 = self._start_pserver(0, False, True, run_pserver)
+        self._wait_ps_ready(p0.pid)
+        port0 = self._get_pserver_port(p0.pid)
+
+        p1 = self._start_pserver(1, False, True, run_pserver)
+        self._wait_ps_ready(p1.pid)
+        port1 = self._get_pserver_port(p1.pid)
+
+        places = [core.CPUPlace()]
+
+        for place in places:
+            self._run_nce_op_two_pserver(place, port0, port1)
+
+        # raise SIGTERM to pserver
+        os.kill(p0.pid, signal.SIGINT)
+        p0.join()
+        os.kill(p1.pid, signal.SIGINT)
+        p1.join()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py
index c128843885fbce29893a4b24c65482abaf870e82..07343b4051e0f44996d1d4617e2cbd1a0d22ce3e 100644
--- a/python/paddle/fluid/transpiler/distribute_transpiler.py
+++ b/python/paddle/fluid/transpiler/distribute_transpiler.py
@@ -251,11 +251,10 @@ class DistributeTranspiler(object):
 
     def _get_all_remote_sparse_update_op(self, main_program):
         sparse_update_ops = []
-        sparse_update_op_types = ["lookup_table"]
+        sparse_update_op_types = ["lookup_table", "nce", "hierarchical_sigmoid"]
         for op in main_program.global_block().ops:
             if op.type in sparse_update_op_types and op.attr(
-                    'remote_prefetch') is True and not op.attr(
-                        'is_distributed'):
+                    'remote_prefetch') is True:
                 sparse_update_ops.append(op)
         return sparse_update_ops