diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index ecca0d4f7a0f6d2fff91f363ae99eefba84ed975..bdf7e7c1248d8378c9e66389dc4297294982f1c7 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -397,72 +397,72 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
   }
   platform::DeviceContextPool::Instance().Get(place_)->Wait();
 
-  VLOG(3) << "start checking";
-    auto& dev_ctx = *platform::DeviceContextPool::Instance().Get(place_);
-  std::vector<std::string> outputs;
-  auto& block = ctx->prog_.Block(0);
-
-  for(auto& op : block.AllOps()) {
-    if(op->Type() == "load_combine" || op->Type() == "fetch" || op->Type() == "feed") continue;
-    // for(auto& real_op : ctx->ops_) {
-    //   if(real_op->Type() == op->Type()) {
-    //     VLOG(3) << real_op->Type() << " " <<place_ << " " << real_op->DebugStringEx(local_scope);
-    //   }
-    // }
+  // VLOG(3) << "start checking";
+  //   auto& dev_ctx = *platform::DeviceContextPool::Instance().Get(place_);
+  // std::vector<std::string> outputs;
+  // auto& block = ctx->prog_.Block(0);
+
+  // for(auto& op : block.AllOps()) {
+  //   if(op->Type() == "load_combine" || op->Type() == "fetch" || op->Type() == "feed") continue;
+  //   // for(auto& real_op : ctx->ops_) {
+  //   //   if(real_op->Type() == op->Type()) {
+  //   //     VLOG(3) << real_op->Type() << " " <<place_ << " " << real_op->DebugStringEx(local_scope);
+  //   //   }
+  //   // }
      
-     //VLOG(3) << "start op output" << op->Type();
-      for(auto var_name: op->InputArgumentNames()) {
-      auto* var = local_scope->Var(var_name);
-      auto* var_desc = block.FindVar(var_name);
-      if (var_desc->Persistable()) continue;
-      auto* tensor = var->GetMutable<framework::LoDTensor>();
-      framework::Tensor check;
-      VLOG(3) << "before tensor copy";
+  //    //VLOG(3) << "start op output" << op->Type();
+  //     for(auto var_name: op->InputArgumentNames()) {
+  //     auto* var = local_scope->Var(var_name);
+  //     auto* var_desc = block.FindVar(var_name);
+  //     if (var_desc->Persistable()) continue;
+  //     auto* tensor = var->GetMutable<framework::LoDTensor>();
+  //     framework::Tensor check;
+  //     VLOG(3) << "before tensor copy";
    
-      framework::TensorCopy(*tensor, platform::CPUPlace(), dev_ctx, &check);
+  //     framework::TensorCopy(*tensor, platform::CPUPlace(), dev_ctx, &check);
       
-      VLOG(3) << "after tensor copy";
-      float sum = .0;
-      for(size_t i=0; i < check.numel(); ++i) {
-        if(std::type_index(check.type()) == std::type_index(typeid(int64_t))) {
-          sum += static_cast<float>(check.data<int64_t>()[i]);
-        } else {
-          sum += check.data<float>()[i];
-        }
-      }
-      VLOG(3) << "op " << op->Type() << " input var " << var_name << " sum " << sum;
-    }
-
-    VLOG(3) << "op " << op->Type() << "input finished";
-    for(auto var_name: op->OutputArgumentNames()) {
-      auto* var = local_scope->Var(var_name);
-      auto* var_desc = block.FindVar(var_name);
-      if (var_desc->Persistable()) continue;
-      auto* tensor = var->GetMutable<framework::LoDTensor>();
-      framework::Tensor check;
-      VLOG(3) << "before tensor copy";
-      if(op->Type() == "batch_norm" && platform::is_gpu_place(place_)) {
-        VLOG(3) << "op " << op->Type() << " output var " << var_name << " " << tensor->numel();
-        tensor->mutable_data<float>(place_);
-         framework::TensorCopy(*tensor, platform::CPUPlace(), dev_ctx, &check);
-      } else {
-         framework::TensorCopy(*tensor, platform::CPUPlace(), dev_ctx, &check);
-      }
+  //     VLOG(3) << "after tensor copy";
+  //     float sum = .0;
+  //     for(size_t i=0; i < check.numel(); ++i) {
+  //       if(std::type_index(check.type()) == std::type_index(typeid(int64_t))) {
+  //         sum += static_cast<float>(check.data<int64_t>()[i]);
+  //       } else {
+  //         sum += check.data<float>()[i];
+  //       }
+  //     }
+  //     VLOG(3) << "op " << op->Type() << " input var " << var_name << " sum " << sum;
+  //   }
+
+  //   VLOG(3) << "op " << op->Type() << "input finished";
+  //   for(auto var_name: op->OutputArgumentNames()) {
+  //     auto* var = local_scope->Var(var_name);
+  //     auto* var_desc = block.FindVar(var_name);
+  //     if (var_desc->Persistable()) continue;
+  //     auto* tensor = var->GetMutable<framework::LoDTensor>();
+  //     framework::Tensor check;
+  //     VLOG(3) << "before tensor copy";
+  //     if(op->Type() == "batch_norm" && platform::is_gpu_place(place_)) {
+  //       VLOG(3) << "op " << op->Type() << " output var " << var_name << " " << tensor->numel();
+  //       tensor->mutable_data<float>(place_);
+  //        framework::TensorCopy(*tensor, platform::CPUPlace(), dev_ctx, &check);
+  //     } else {
+  //        framework::TensorCopy(*tensor, platform::CPUPlace(), dev_ctx, &check);
+  //     }
       
-      VLOG(3) << "after tensor copy";
-      float sum = .0;
-      for(size_t i=0; i < check.numel(); ++i) {
-        if(std::type_index(check.type()) == std::type_index(typeid(int64_t))) {
-          sum += static_cast<float>(check.data<int64_t>()[i]);
-        } else {
-          sum += check.data<float>()[i];
-        }
-      }
-      VLOG(3) << "op " << op->Type() << " output var " << var_name << " sum " << sum;
-    }
-  }
-
-  VLOG(3) << "after checking result";
+  //     VLOG(3) << "after tensor copy";
+  //     float sum = .0;
+  //     for(size_t i=0; i < check.numel(); ++i) {
+  //       if(std::type_index(check.type()) == std::type_index(typeid(int64_t))) {
+  //         sum += static_cast<float>(check.data<int64_t>()[i]);
+  //       } else {
+  //         sum += check.data<float>()[i];
+  //       }
+  //     }
+  //     VLOG(3) << "op " << op->Type() << " output var " << var_name << " sum " << sum;
+  //   }
+  // }
+
+  // VLOG(3) << "after checking result";
 
   if (local_scope != scope) {
     scope->DeleteScope(local_scope);
diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc
index 0ed9bab2464215cb7656f0c0835d6d36dcfbdfa8..aaf6d5a4f3051ff5f3f6b1afc047c9ebd936fb5b 100644
--- a/paddle/fluid/inference/api/api_impl.cc
+++ b/paddle/fluid/inference/api/api_impl.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <algorithm>
+#include <fstream>
 #include <map>
 #include <set>
 #include <sstream>
@@ -88,6 +89,7 @@ bool NativePaddlePredictor::Init(
     VLOG(3) << config_.model_dir;
     inference_program_ = paddle::inference::Load(executor_.get(), scope_.get(),
                                                  config_.model_dir);
+
     VLOG(3) << "load model finish";
   } else if (!config_.prog_file.empty() && !config_.param_file.empty()) {
     // All parameters are saved in a single file.
@@ -100,6 +102,31 @@ bool NativePaddlePredictor::Init(
     VLOG(3) << "scope_";
     inference_program_ = paddle::inference::Load(
         executor_.get(), scope_.get(), config_.prog_file, config_.param_file);
+    // VLOG(3) << "modify the program!";
+    // {
+    //   std::ofstream ofs("program.txt", std::ios::out);
+    //   std::string s = inference_program_->Proto()->SerializeAsString();
+    //   ofs.write(s.data(), s.size());
+    //   ofs.close();
+    // }
+
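+    // Descriptive note (assumption about intent): force every op that exposes
+    // "use_cudnn" to skip its cuDNN kernel and set its "workspace_size_MB"
+    // limit to 0 so no extra conv workspace is requested.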
+    auto &block = inference_program_->Block(0);
+    for (auto *op_desc : block.AllOps()) {
+      if (op_desc->HasAttr("use_cudnn")) {
+        op_desc->SetAttr("use_cudnn", false);
+      }
+      if (op_desc->HasAttr("workspace_size_MB")) {
+        op_desc->SetAttr("workspace_size_MB", 0);
+      }
+    }
+
+    // {
+    //   std::ofstream ofs("after_program.txt", std::ios::out);
+    //   std::string s = inference_program_->Proto()->SerializeAsString();
+    //   ofs.write(s.data(), s.size());
+    //   ofs.close();
+    // }
+
     VLOG(3) << "load program finish";
   } else {
     LOG(ERROR) << "fail to load inference model.";
@@ -306,9 +333,10 @@ std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
   if (config.use_gpu) {
     // 1. GPU memeroy
     VLOG(3) << "before check";
-   // PADDLE_ENFORCE_GT(
+    // PADDLE_ENFORCE_GT(
     //    config.fraction_of_gpu_memory, 0.f,
-    //    "fraction_of_gpu_memory in the config should be set to range (0., 1.]");
+    //    "fraction_of_gpu_memory in the config should be set to range (0.,
+    //    1.]");
     VLOG(3) << "failed on first";
     PADDLE_ENFORCE_GE(config.device, 0, "Invalid device id %d", config.device);
     VLOG(3) << "after flags";
diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt
index c82837a490ef63f205ba5d327a16a91dcb4275a5..db2a7acfda3bd33e930e72f109aaf25ad25721ce 100644
--- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt
+++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt
@@ -77,7 +77,7 @@ add_executable(real_data_icnet_tester real_data_icnet_tester.cc)
 
 # add_library(${DEMO_NAME} SHARED  ${DEMO_NAME}.cc)
 # add_executable(test test.cc)
-# add_executable(thread_icnet_test thread_icnet_test.cc)
+add_executable(thread_icnet_test thread_icnet_test.cc)
 
 if(WITH_MKL)
   include_directories("${PADDLE_LIB}/third_party/install/mklml/include")
@@ -130,6 +130,5 @@ target_link_libraries(real_data_icnet_tester ${DEPS})
 
 # target_link_libraries(${DEMO_NAME} ${DEPS})
 # target_link_libraries(test ${DEMO_NAME} )
-# target_link_libraries(thread_icnet_test ${DEPS})
+target_link_libraries(thread_icnet_test ${DEPS})
 # target_compile_definitions(${DEMO_NAME} PRIVATE "API_DEFINITION")
-
diff --git a/paddle/fluid/inference/api/demo_ci/real_data_icnet_tester.cc b/paddle/fluid/inference/api/demo_ci/real_data_icnet_tester.cc
index ae5f130504e8d5a7db9e0605b2805523577748b7..1b6463a333cffafc67540ffbb537c9c96ee36903 100644
--- a/paddle/fluid/inference/api/demo_ci/real_data_icnet_tester.cc
+++ b/paddle/fluid/inference/api/demo_ci/real_data_icnet_tester.cc
@@ -25,10 +25,13 @@ namespace paddle {
 
 NativeConfig GetConfig() {
   NativeConfig config;
+
   // config.model_dir = FLAGS_dirname;
-  config.prog_file= "hs_lb_without_bn/__model__";
-  config.param_file= "hs_lb_without_bn/__params__";
-  config.fraction_of_gpu_memory = 0.8;
+  config.prog_file = "hs_lb_without_bn/__model__";
+  config.param_file = "hs_lb_without_bn/__params__";
+  // config.prog_file = "hs_lb_without_bn_cuda/__model__";
+  // config.param_file = "hs_lb_without_bn_cuda/__params__";
+  config.fraction_of_gpu_memory = 0.0;
   config.use_gpu = true;
   config.device = 0;
   return config;
@@ -43,13 +46,12 @@ double time_diff(Time t1, Time t2) {
   return counter.count() / 1000.0;
 }
 
-
-void test_naive(int batch_size){
+void test_naive(int batch_size) {
   NativeConfig config = GetConfig();
   auto predictor = CreatePaddlePredictor<NativeConfig>(config);
   int height = 449;
   int width = 581;
-  
+
   // =============read file list =============
   std::ifstream infile("new_file.list");
   std::string temp_s;
@@ -62,61 +64,65 @@ void test_naive(int batch_size){
   // size_t file_num = all_files.size();
   infile.close();
   // =============read file list =============
-  for (size_t f_k = 0; f_k < 1; f_k ++) {
-          std::ifstream in_img(all_files[f_k]);
-          std::cout << all_files[f_k] << std::endl;
-          float temp_v;
+  for (size_t f_k = 0; f_k < 1; f_k++) {
+    std::ifstream in_img(all_files[f_k]);
+    std::cout << all_files[f_k] << std::endl;
+    float temp_v;
 
-         float sum_n = 0.0;
-	 std::vector<float> data;
-         while (!in_img.eof()) {
-            in_img >> temp_v;
-            data.push_back(float(temp_v));
-            // std::cout << temp_v << " ";
-            sum_n += temp_v;
-         }
+    float sum_n = 0.0;
+    std::vector<float> data;
+    while (!in_img.eof()) {
+      in_img >> temp_v;
+      data.push_back(float(temp_v));
+      // std::cout << temp_v << " ";
+      sum_n += temp_v;
+    }
 
-          in_img.close();
-          std::cout << "sum: " << sum_n << std::endl;
-          
-	  PaddleTensor tensor;
-	  tensor.shape = std::vector<int>({batch_size, 3, height, width});
-          tensor.data.Resize(sizeof(float) * batch_size * 3 * height * width);
-          std::copy(data.begin(), data.end(), static_cast<float*>(tensor.data.data()));
-	  tensor.dtype = PaddleDType::FLOAT32;
-	  std::vector<PaddleTensor> paddle_tensor_feeds(1, tensor);
-	  PaddleTensor tensor_out;
+    in_img.close();
+    std::cout << "sum: " << sum_n << std::endl;
 
-	  std::vector<PaddleTensor> outputs(1, tensor_out);
-	  // predictor->Run(paddle_tensor_feeds, &outputs, batch_size);
-	  std::cout << "start predict123:" << std::endl;
-	  auto time1 = time(); 
-	  
-	  for(size_t i = 0; i < 1; i++) {
-	    predictor->Run(paddle_tensor_feeds, &outputs, batch_size);
-	  } 
+    PaddleTensor tensor;
+    tensor.shape = std::vector<int>({batch_size, 3, height, width});
+    tensor.data.Resize(sizeof(float) * batch_size * 3 * height * width);
+    std::copy(data.begin(), data.end(),
+              static_cast<float*>(tensor.data.data()));
+    tensor.dtype = PaddleDType::FLOAT32;
+    std::vector<PaddleTensor> paddle_tensor_feeds(1, tensor);
+    PaddleTensor tensor_out;
 
-	  auto time2 = time(); 
-	  std::ofstream ofresult("naive_test_result.txt", std::ios::app);
+    std::vector<PaddleTensor> outputs(1, tensor_out);
+    // predictor->Run(paddle_tensor_feeds, &outputs, batch_size);
+    std::cout << "start predict123:" << std::endl;
+    auto time1 = time();
+    int steps = 100;
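+    // The first 5 iterations act as warm-up; the timer restarts at i == 5 and
+    // only the remaining iterations are averaged below.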
+    for (int i = 0; i < steps; i++) {
+      if (i == 5) time1 = time();
+      predictor->Run(paddle_tensor_feeds, &outputs, batch_size);
+    }
 
-	  std::cout <<"batch: " << batch_size << " predict cost: " << time_diff(time1, time2) / 1000.0 << "ms" << std::endl;
-          std::cout << outputs.size() << std::endl;
-	  int64_t * data_o = static_cast<int64_t*>(outputs[0].data.data());
+    auto time2 = time();
+    std::ofstream ofresult("naive_test_result.txt", std::ios::app);
+
+    std::cout << "batch: " << batch_size
+              << " predict cost: " << time_diff(time1, time2) / steps << "ms"
+              << std::endl;
+    std::cout << outputs.size() << std::endl;
+    int64_t* data_o = static_cast<int64_t*>(outputs[0].data.data());
     int64_t sum_out = 0;
-	  for (size_t j = 0; j < outputs[0].data.length() / sizeof(int64_t); ++j) {
-	    ofresult << std::to_string(data_o[j]) << " ";
+    for (size_t j = 0; j < outputs[0].data.length() / sizeof(int64_t); ++j) {
+      ofresult << std::to_string(data_o[j]) << " ";
       sum_out += data_o[j];
-	  }
+    }
     std::cout << "sum_out " << sum_out << std::endl;
-	  ofresult << std::endl;
-	  ofresult.close();
- }
+    ofresult << std::endl;
+    ofresult.close();
+  }
 }
 
 }  // namespace paddle
 
 int main(int argc, char** argv) {
-//  google::ParseCommandLineFlags(&argc, &argv, true);
-  paddle::test_naive(1<<0);
+  //  google::ParseCommandLineFlags(&argc, &argv, true);
+  paddle::test_naive(1 << 0);
   return 0;
 }
diff --git a/paddle/fluid/inference/api/demo_ci/thread_icnet_test.cc b/paddle/fluid/inference/api/demo_ci/thread_icnet_test.cc
index d669b04dc91f2823216c0aa5d067142ee164b869..9a018ee347e821035d14a27bbab236fd01c2d48b 100644
--- a/paddle/fluid/inference/api/demo_ci/thread_icnet_test.cc
+++ b/paddle/fluid/inference/api/demo_ci/thread_icnet_test.cc
@@ -20,22 +20,21 @@
 #include <chrono>
 #include <fstream>
 #include <iostream>
-#include "paddle/fluid/inference/api/paddle_inference_api.h"
 #include <thread>  // NOLINT
+#include "paddle/fluid/inference/api/paddle_inference_api.h"
 
 #define ASSERT_TRUE(x) x
 #define ASSERT_EQ(x, y) assert(x == y)
 
-namespace paddle {
 
 // DEFINE_string(dirname, "./LB_icnet_model",
 //               "Directory of the inference model.");
-
+namespace paddle {
 NativeConfig GetConfig() {
   NativeConfig config;
-  config.prog_file= "./dzh_lb/__model__";
-  config.param_file= "./dzh_lb/__params__";
-  config.fraction_of_gpu_memory = 0.08;
+  config.prog_file = "./hs_lb_without_bn_cuda/__model__";
+  config.param_file = "./hs_lb_without_bn_cuda/__params__";
+  config.fraction_of_gpu_memory = 0.5;
   config.use_gpu = true;
   config.device = 0;
   return config;
@@ -50,56 +49,84 @@ double time_diff(Time t1, Time t2) {
   return counter.count() / 1000.0;
 }
 
-void test_naive(int batch_size, std::string model_path){
-  PaddlePredictor* pres[2];
-  
+void test_naive(int batch_size, std::string model_path) {
   NativeConfig config = GetConfig();
-  // config.model_dir = model_path;
-  auto predictor0 = CreatePaddlePredictor<NativeConfig>(config);
-  auto predictor1 = CreatePaddlePredictor<NativeConfig>(config);
-  pres[0] = predictor0.get();
-  pres[1] = predictor1.get();
-
   int height = 449;
   int width = 581;
-  
   std::vector<float> data;
-  for (int i = 0; i < 3 * height * width; i++) {
-    data.push_back(0);
-  }
-  
-  PaddleTensor tensor;
-  tensor.shape = std::vector<int>({batch_size, 3, height, width});
-  tensor.data.Resize(sizeof(float) * batch_size * 3 * height * width);
-  std::copy(data.begin(), data.end(), static_cast<float*>(tensor.data.data()));
-  tensor.dtype = PaddleDType::FLOAT32;
-  std::vector<PaddleTensor> paddle_tensor_feeds(1, tensor);
-
-  constexpr int num_jobs = 5;  // each job run 1 batch
-  std::vector<std::thread> threads;
-  for (int tid = 0; tid < num_jobs; ++tid) {
-    threads.emplace_back([&, tid]() {
-      auto predictor = pres[tid];
-      std::vector<PaddleTensor> local_outputs;
-     for(size_t i = 0; i < 1000; i++) {
-      ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &local_outputs));
-      std::cout << "run: " << tid << std::endl; 
-      }
-      ASSERT_EQ(local_outputs.size(), 1UL);
-    });
+  for (int i = 0; i < 3 * height * width; ++i) {
+    data.push_back(0.0);
   }
-  for (int i = 0; i < num_jobs; ++i) {
-    threads[i].join();
-  }
-}
 
-//TEST(alexnet, naive) {
-//  test_naive(1 << 0, "./trt_models/vgg19");
-//}
+  // read data
+  // std::ifstream infile("new_file.list");
+  // std::string temp_s;
+  // std::vector<std::string> all_files;
+  // while (!infile.eof()) {
+  //   infile >> temp_s;
+  //   all_files.push_back(temp_s);
+  // }
 
-}  // namespace paddle
+  // // size_t file_num = all_files.size();
+  // infile.close();
+  // // =============read file list =============
+  // for (size_t f_k = 0; f_k < 1; f_k++) {
+  //   std::ifstream in_img(all_files[f_k]);
+  //   std::cout << all_files[f_k] << std::endl;
+  //   float temp_v;
 
-int main(int argc, char** argv) {
-	paddle::test_naive(1 << 0, "");
-}
+  //   float sum_n = 0.0;
+  //   std::vector<float> data;
+  //   while (!in_img.eof()) {
+  //     in_img >> temp_v;
+  //     data.push_back(float(temp_v));
+
+  //     sum_n += temp_v;
+  //   }
+  //   in_img.close();
+  //   std::cout << "sum: " << sum_n << std::endl;
+
+  PaddleTensor tensor;
+  tensor.shape = std::vector<int>({batch_size, 3, height, width});
+  tensor.data.Resize(sizeof(float) * batch_size * 3 * height * width);
+  std::copy(data.begin(), data.end(),
+            static_cast<float*>(tensor.data.data()));
+  tensor.dtype = PaddleDType::FLOAT32;
+  std::vector<PaddleTensor> paddle_tensor_feeds(1, tensor);
+
+  constexpr int num_jobs = 2;  // each job runs 1 batch
+  std::vector<std::thread> threads;
 
+
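+  // Each worker thread builds its own predictor from the shared NativeConfig
+  // and runs the same feed repeatedly, exercising concurrent inference.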
+  for (int tid = 0; tid < num_jobs; ++tid) {
+    threads.emplace_back([&, tid]() {
+      PaddleTensor tensor_out;
+      std::vector<PaddleTensor> outputs(1, tensor_out);
+      auto predictor = CreatePaddlePredictor<NativeConfig>(config);
+      for (size_t i = 0; i < 1000; i++) {
+        ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
+        VLOG(0) << "tid : " << tid << " run: " << i << " finished";
+        // std::cout << "tid : " << tid << " run: " << i << " finished"
+        //           << std::endl;
+        ASSERT_EQ(outputs.size(), 1UL);
+        // int64_t* data_o = static_cast<int64_t*>(outputs[0].data.data());
+        // int64_t sum_out = 0;
+        // for (size_t j = 0; j < outputs[0].data.length() / sizeof(int64_t);
+        //      ++j) {
+        //   sum_out += data_o[j];
+        // }
+        // std::cout << "tid : " << tid << "pass : " << i << " " << sum_out
+        //           << std::endl;
+      }
+    });
+  }
+  for (int i = 0; i < num_jobs; ++i) {
+    threads[i].join();
+  }
+}
+}  // namespace paddle
+
+int main(int argc, char** argv) {
+  paddle::test_naive(1 << 0, "");
+  return 0;
+}
diff --git a/paddle/fluid/operators/conv_cudnn_op.cu.cc b/paddle/fluid/operators/conv_cudnn_op.cu.cc
index 5bee83c9abb00e0ab097b02d7e12b74cc10d66ad..7e859c1bcc00d69b834ccfb50daee23fb0c0e886 100644
--- a/paddle/fluid/operators/conv_cudnn_op.cu.cc
+++ b/paddle/fluid/operators/conv_cudnn_op.cu.cc
@@ -163,7 +163,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
     VLOG(3) << "after get workspace";
     // Allocate on GPU memory
     platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
-    workspace_size_in_bytes = 1024;
+    // workspace_size_in_bytes = 1024;
     cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
     VLOG(3) << "allocate memory";
     // ------------------- cudnn conv forward ---------------------
@@ -324,7 +324,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
     // Already on GPU
     void* cudnn_workspace = nullptr;
     platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
-    workspace_size_in_bytes = 1024;
+    // workspace_size_in_bytes = 1024;
     cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
     // ------------------- cudnn conv backward data ---------------------
     ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc
index 14c0a464543ee2206201577cf1526de53336e424..267313b7f8ac2a69bd2d821f4d942410ce8ce939 100644
--- a/paddle/fluid/operators/load_combine_op.cc
+++ b/paddle/fluid/operators/load_combine_op.cc
@@ -62,18 +62,18 @@ class LoadCombineOp : public framework::OperatorBase {
       VLOG(3) << "before deserialization";
       // Get data from fin to tensor
       DeserializeFromStream(fin, tensor, dev_ctx); 
-      VLOG(3) << "after deserialization";
-      framework::Tensor check;
-      framework::TensorCopy(*tensor, platform::CPUPlace(), dev_ctx, &check);
-      float sum = .0;
-      for(size_t i=0; i < check.numel(); ++i) {
-        if(std::type_index(check.type()) == std::type_index(typeid(int64_t))) {
-          sum += static_cast<float>(check.data<int64_t>()[i]);
-        } else {
-          sum += check.data<float>()[i];
-        }
-      }
-      VLOG(3) << "sum result" << sum;
+      // VLOG(3) << "after deserialization";
+      // framework::Tensor check;
+      // framework::TensorCopy(*tensor, platform::CPUPlace(), dev_ctx, &check);
+      // float sum = .0;
+      // for(size_t i=0; i < check.numel(); ++i) {
+      //   if(std::type_index(check.type()) == std::type_index(typeid(int64_t))) {
+      //     sum += static_cast<float>(check.data<int64_t>()[i]);
+      //   } else {
+      //     sum += check.data<float>()[i];
+      //   }
+      // }
+      // VLOG(3) << "sum result" << sum;
       auto in_dtype = framework::ToDataType(tensor->type());
       auto out_dtype =
           load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc
index 4a8ac441cfaf642fde58ee30865a22e83c065498..c17d1afc309c65035063348d4934ea1783b018ed 100644
--- a/paddle/fluid/operators/top_k_op.cc
+++ b/paddle/fluid/operators/top_k_op.cc
@@ -50,7 +50,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X", "(Tensor) The input of Topk op");
-    AddOutput("Out", "(Tensor) The output tensor of Topk op").Reuse("X");
+    AddOutput("Out", "(Tensor) The output tensor of Topk op");
     AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
     AddComment(R"DOC(
 Top K operator
diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu
index 9da8551eb2d7ea66ad434c42b54522432095ce29..0cad224ca8860b0e4bc2e3f2bc1659235aadfe2d 100644
--- a/paddle/fluid/operators/top_k_op.cu
+++ b/paddle/fluid/operators/top_k_op.cu
@@ -256,36 +256,65 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
  * 3. go to the second setp, until one thread's topk value is null;
  * 4. go to the first setp, until get the topk value.
  */
+
 template <typename T, int MaxLength, int BlockSize>
 __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
-                             const T* src, int lds, int dim, int k) {
+                             const T* src, int lds, int dim, int k,
+                             int grid_dim, int num) {
   __shared__ Pair<T> sh_topk[BlockSize];
-  __shared__ int maxid[BlockSize / 2];
   const int tid = threadIdx.x;
   const int warp = threadIdx.x / 32;
-  output += blockIdx.x * output_stride;
-  indices += blockIdx.x * k;
 
-  Pair<T> topk[MaxLength];
-  int beam = MaxLength;
-  Pair<T> max;
-  bool is_empty = false;
-  bool firststep = true;
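+  // Grid-stride loop: each block walks over several rows when the row count
+  // exceeds the launched grid size.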
+  const int bid = blockIdx.x;
+  for (int i = bid; i < num; i += grid_dim) {
+    int top_num = k;
+    __shared__ int maxid[BlockSize / 2];
+    T* out = output + i * output_stride;
+    int64_t* inds = indices + i * k;
+    Pair<T> topk[MaxLength];
+    int beam = MaxLength;
+    Pair<T> max;
+    bool is_empty = false;
+    bool firststep = true;
+
+    for (int j = 0; j < MaxLength; j++) {
+      topk[j].set(-INFINITY, -1);
+    }
+    while (top_num) {
+      ThreadGetTopK<T, MaxLength, BlockSize>(
+          topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid);
 
-  for (int k = 0; k < MaxLength; k++) {
-    topk[k].set(-INFINITY, -1);
+      sh_topk[tid] = topk[0];
+      BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds,
+                                           &beam, &top_num, tid, warp);
+    }
   }
-  while (k) {
-    ThreadGetTopK<T, MaxLength, BlockSize>(topk, &beam, k,
-                                           src + blockIdx.x * lds, &firststep,
-                                           &is_empty, &max, dim, tid);
-
-    sh_topk[tid] = topk[0];
-    BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
-                                         &indices, &beam, &k, tid, warp);
+}
+
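+// Pick the CUDA block size for one row based on its width; the
+// FIXED_BLOCK_DIM macros below expand to one switch case per supported size
+// so the kernel can be instantiated with a compile-time kBlockDim.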
+inline static int GetDesiredBlockDim(int dim) {
+  if (dim > 128) {
+    return 256;
+  } else if (dim > 64) {
+    return 128;
+  } else if (dim > 32) {
+    return 64;
+  } else {
+    return 32;
   }
 }
 
+#define FIXED_BLOCK_DIM_BASE(dim, ...) \
+  case (dim): {                        \
+    constexpr auto kBlockDim = (dim);  \
+    __VA_ARGS__;                       \
+  } break
+
+#define FIXED_BLOCK_DIM(...)                \
+  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
+
 template <typename T>
 class TopkOpCUDAKernel : public framework::OpKernel<T> {
  public:
@@ -298,30 +327,38 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
     size_t k = static_cast<int>(ctx.Attr<int>("k"));
 
     const T* input_data = input->data<T>();
-
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
     // FIXME(typhoonzero): data is always converted to type T?
     int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
 
-    size_t input_height = input->dims()[0];
-    size_t input_width = input->dims()[1];
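+    // Flatten all leading dimensions into rows; top-k runs along the last
+    // axis, so inputs are no longer restricted to 2-D tensors.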
+    framework::DDim inputdims = input->dims();
+    const size_t input_height = framework::product(
+        framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
+    const size_t input_width = inputdims[inputdims.size() - 1];
+
     if (k > input_width) k = input_width;
 
     // NOTE: pass lds and dim same to input width.
     // NOTE: old matrix implementation of stride is different to eigen.
     // TODO(typhoonzero): refine this kernel.
-    dim3 threads(256, 1);
-    dim3 grid(input_height, 1);
-
-    KeMatrixTopK<T, 5, 256><<<
-        grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
-                              ctx.device_context())
-                              .stream()>>>(
-        output_data, output->dims()[1], indices_data, input_data, input_width,
-        input_width, static_cast<int>(k));
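+    // Cap the launch grid at kMaxHeight blocks; rows beyond that are covered
+    // by the grid-stride loop inside KeMatrixTopK.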
+    const int kMaxHeight = 2048;
+    int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
+    auto& dev_ctx = ctx.cuda_device_context();
+    switch (GetDesiredBlockDim(input_width)) {
+      FIXED_BLOCK_DIM(
+          KeMatrixTopK<T, 5,
+                       kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
+              output_data, k, indices_data, input_data, input_width,
+              input_width, static_cast<int>(k), gridx, input_height));
+      default:
+        PADDLE_THROW("Unsupported block dimension for top_k CUDA kernel.");
+    }
   }
 };
 
+#undef FIXED_BLOCK_DIM_BASE
+#undef FIXED_BLOCK_DIM
+
 }  // namespace operators
 }  // namespace paddle
 
diff --git a/paddle/fluid/operators/top_k_op.h b/paddle/fluid/operators/top_k_op.h
index 054dd481994d03f71b0ed5dc73e103085f6c91aa..76ece57b39919148da04caecaa43ea9d2b9d95df 100644
--- a/paddle/fluid/operators/top_k_op.h
+++ b/paddle/fluid/operators/top_k_op.h
@@ -34,7 +34,6 @@ class TopkKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     // Get the top k elements of each row of input tensor
-    // FIXME: only deal with matrix(2d tensor).
     auto* input = ctx.Input<Tensor>("X");
     auto* output = ctx.Output<Tensor>("Out");
     auto* indices = ctx.Output<Tensor>("Indices");
@@ -44,8 +43,6 @@ class TopkKernel : public framework::OpKernel<T> {
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
     int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
 
-    auto eg_input = EigenMatrix<T>::From(*input);
-
     // reshape input to a flattern matrix(like flat_inner_dims)
     framework::DDim inputdims = input->dims();
     const size_t row = framework::product(
@@ -53,7 +50,7 @@ class TopkKernel : public framework::OpKernel<T> {
     const size_t col = inputdims[inputdims.size() - 1];
     Eigen::DSizes<int, 2> flat2dims(row, col);
     // NOTE: eigen shape doesn't affect paddle tensor.
-    eg_input.reshape(flat2dims);
+    auto eg_input = EigenMatrix<T>::Reshape(*input, inputdims.size() - 1);
 
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for