diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index 134c894392a604875780fcfc8ea93e06c9d48bdd..f43d284ad0b874d4c659cf9f2e87666b6a000d90 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -191,9 +191,6 @@ function(op_library TARGET)
         file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
       elseif(${TARGET} STREQUAL "tensorrt_engine_op")
           message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
-      elseif(${TARGET} STREQUAL "fc")
-        # HACK: fc only have mkldnn and cpu, which would mismatch the cpu only condition
-        file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
       else()
         file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
       endif()
diff --git a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
index 6462e7bf4c099a1abb98a77d905067628b8eb88c..b29b37992dddd1caae5222f1aab403d76c13dad7 100644
--- a/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc
@@ -21,7 +21,6 @@
 
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
-#include "paddle/fluid/operators/math/fc_compute.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
 namespace paddle {
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 1b58243aaa3bda4df49c57360e36e83079c27ff8..83d10abd5f3ef872d71175369e343373c2d07a07 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -103,9 +103,10 @@ const std::vector<std::string> kAnakinSubgraphPasses({
 
 GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
-    //   "identity_scale_op_clean_pass",              //
+    //   "identity_scale_op_clean_pass",             //
     "is_test_pass",                                  //
         "simplify_with_basic_ops_pass",              //
+        "fc_fuse_pass",                              //
         "conv_affine_channel_fuse_pass",             //
         "conv_eltwiseadd_affine_channel_fuse_pass",  //
         "conv_bn_fuse_pass",                         //
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index 66e5522ae89e972cabed848aaf0ff1de2da71bb2..8b45a0b0311531f24cef731acd94f7bae59b4836 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -90,7 +90,7 @@ endif()
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
-set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search)
+set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper)
 if (WITH_GPU)
   set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc
index 275855cbb6e4c7323a0c57a8e00ed0b7fb7f8f9c..c6d98f1f9a534aa98923afc1ead0ffc1f83a8b99 100644
--- a/paddle/fluid/operators/attention_lstm_op.cc
+++ b/paddle/fluid/operators/attention_lstm_op.cc
@@ -16,7 +16,7 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
-#include "paddle/fluid/operators/math/fc_compute.h"
+#include "paddle/fluid/operators/math/fc.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
 namespace paddle {
@@ -339,10 +339,13 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     T* lstm_x_data = lstm_x->mutable_data<T>(ctx.GetPlace());
     T* lstm_out_data = lstm_out->mutable_data<T>(ctx.GetPlace());
 
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
+
     // x(TxM) * fc (Mx1) part of atten_wgt(M+D)x1
-    auto blas = math::GetBlas<DeviceContext, T>(ctx);
-    math::FCCompute<DeviceContext, T>(blas, total_T, 1, M, x_data, atten_w_data,
-                                      atted_x_data, atten_b_data);
+    auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+    math::FCFunctor<DeviceContext, T> fc;
+    fc(dev_ctx, total_T, 1, M, x_data, atten_w_data, atted_x_data,
+       atten_b_data);
 
     const T* cur_atten_x_data = atted_x_data;
     const T* cur_x_data = x_data;
@@ -369,8 +372,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
         // 1d. softmax
         vec_softmax<T>(seq_len, fc_out_data, fc_out_data);
         // mul x(seq_len*M) and sum pool
-        math::FCCompute<DeviceContext, T>(blas, 1, M, seq_len, fc_out_data,
-                                          cur_x_data, lstm_x_data);
+        fc(dev_ctx, 1, M, seq_len, fc_out_data, cur_x_data, lstm_x_data);
 
         /// 2. compute LSTM step
         // lstm weight : concat[forget , input , output , tilde]
diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc
index 242f5390b806756283686dae2e2c32b93c2bd71e..bc0edd780c253ea01da26b88a4d43ee48571f2a2 100644
--- a/paddle/fluid/operators/fc_op.cc
+++ b/paddle/fluid/operators/fc_op.cc
@@ -14,65 +14,76 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/fc_op.h"
 #include <vector>
-#include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/fc_compute.h"
 
 namespace paddle {
 namespace operators {
 
-void FCOp::InferShape(framework::InferShapeContext* ctx) const {
-  PADDLE_ENFORCE(ctx->HasInput("Input"),
-                 "X(Input) of Fully Connected should not be null.");
-  PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                 "Out(Output) of Fully Connected should not be null.");
-  PADDLE_ENFORCE(ctx->HasInput("W"),
-                 "W(Input) of Fully Connected should not be null.");
-
-  auto in_dims = ctx->GetInputDim("Input");
-  auto w_dims = ctx->GetInputDim("W");
+class FCOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
+                      "X(Input) of Fully Connected should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      "Out(Output) of Fully Connected should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("W"), true,
+                      "W(Input) of Fully Connected should not be null.");
+
+    auto in_dims = ctx->GetInputDim("Input");
+    auto w_dims = ctx->GetInputDim("W");
+
+    if (ctx->HasInput("Bias")) {
+      auto bias_dims = ctx->GetInputDim("Bias");
+      if (bias_dims.size() == 2) {
+        PADDLE_ENFORCE_EQ(bias_dims[0], 1,
+                          "The shape of Bias must be [1, dim].");
+        PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1],
+                          "The shape of Bias must be [1, dim].");
+      } else if (bias_dims.size() == 1) {
+        PADDLE_ENFORCE_EQ(bias_dims[0], w_dims[1],
+                          "The shape of Bias must be [1, dim].");
+      }
+    }
 
-  if (ctx->HasInput("Bias")) {
-    auto bias_dims = ctx->GetInputDim("Bias");
-    if (bias_dims.size() == 2) {
-      PADDLE_ENFORCE_EQ(bias_dims[0], 1, "The shape of Bias must be [1, dim].");
-      PADDLE_ENFORCE_EQ(bias_dims[1], w_dims[1],
-                        "The shape of Bias must be [1, dim].");
-    } else if (bias_dims.size() == 1) {
-      PADDLE_ENFORCE_EQ(bias_dims[0], w_dims[1],
-                        "The shape of Bias must be [1, dim].");
+    auto& activation_type = ctx->Attrs().Get<std::string>("activation_type");
+    if (!activation_type.empty()) {  // when set, only "relu" is accepted
+      PADDLE_ENFORCE_EQ(activation_type, "relu",
+                        "Activation %s is not supported in fc now.",
+                        activation_type.c_str());
     }
-  }
+    if (ctx->Attrs().Get<bool>("use_mkldnn")) {
+      PADDLE_ENFORCE_EQ(in_dims.size() == 2 || in_dims.size() == 4, true,
+                        "Fully Connected input should be 2-D or 4-D tensor.");
+    }
+    PADDLE_ENFORCE_EQ(w_dims.size(), 2,
+                      "Fully Connected input should be 2-D tensor.");
+    int in_num_col_dims = ctx->Attrs().Get<int>("in_num_col_dims");
+    PADDLE_ENFORCE_GT(
+        in_dims.size(), in_num_col_dims,
+        "The input tensor Input's rank of FCOp should be larger than "
+        "in_num_col_dims.");
+
+    std::vector<int64_t> output_dims;
+    FCOutputSize(in_dims, w_dims, output_dims, in_num_col_dims);
 
-  if (ctx->Attrs().Get<bool>("use_mkldnn")) {
-    PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
-                   "Fully Connected input should be 2-D or 4-D tensor.");
+    ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
+    ctx->ShareLoD("Input", "Out");
   }
-  PADDLE_ENFORCE_EQ(w_dims.size(), 2,
-                    "Fully Connected input should be 2-D tensor.");
-  int in_num_col_dims = ctx->Attrs().Get<int>("in_num_col_dims");
-  PADDLE_ENFORCE_GT(
-      in_dims.size(), in_num_col_dims,
-      "The input tensor Input's rank of FCOp should be larger than "
-      "in_num_col_dims.");
-
-  std::vector<int64_t> output_dims;
-  FCOutputSize(in_dims, w_dims, output_dims, in_num_col_dims);
-
-  ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
-  ctx->ShareLoD("Input", "Out");
-}
 
-framework::OpKernelType FCOp::GetExpectedKernelType(
-    const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library = framework::LibraryType::kPlain;
-  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-  if (ctx.Attr<bool>("use_mkldnn")) {
-    library = framework::LibraryType::kMKLDNN;
-    layout = framework::DataLayout::kMKLDNN;
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    framework::LibraryType library = framework::LibraryType::kPlain;
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    if (ctx.Attr<bool>("use_mkldnn")) {
+      library = framework::LibraryType::kMKLDNN;
+      layout = framework::DataLayout::kMKLDNN;
+    }
+    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
+                                   ctx.GetPlace(), layout, library);
   }
-  return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
-                                 ctx.GetPlace(), layout, library);
-}
+};
 
 void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
   auto in_dims = ctx->GetInputDim("Input");
@@ -86,8 +97,8 @@ void FCOpGrad::InferShape(framework::InferShapeContext* ctx) const {
   }
 
   if (ctx->HasInput("Bias")) {
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
-                   "Should have bias grad");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("Bias")), true,
+                      "Should have bias grad");
     auto bias_dims = ctx->GetInputDim("Bias");
     ctx->SetOutputDim(framework::GradVarName("Bias"), bias_dims);
   }
@@ -105,61 +116,36 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
                                  ctx.GetPlace(), layout, library);
 }
 
-void FCOpMaker::Make() {
-  AddInput("Input", "(Tensor), The input tensor of fully connected operator.");
-  AddInput("W", "(Tensor), The weight fc op with shape (I, O).");
-  AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O")
-      .AsDispensable();
-  AddAttr<int>("in_num_col_dims",
-               "(int, default 1), The fc op can take tensors with more than "
-               "two dimensions as its inputs.")
-      .SetDefault(1)
-      .EqualGreaterThan(1);
-  AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
-  AddAttr<bool>("use_mkldnn",
-                "(bool, default false) Only used in mkldnn kernel")
-      .SetDefault(false);
-  AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
-                "Skip calling InferShape() function in the runtime.")
-      .SetDefault(true);
-  AddComment(R"DOC(
-  Fully Connected Operator.
-
-  The fully connected operation calculates the output based on the input, weights and bias.
-  The size of each dimension of the parameters checked in the infer-shape.
-)DOC");
-}
-
-template <typename T>
-class FCOpKernel : public framework::OpKernel<T> {
+class FCOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "It must use CPUPlace.");
-    auto input = ctx.Input<framework::LoDTensor>("Input");
-    auto w = ctx.Input<Tensor>("W");
-    auto bias = ctx.Input<Tensor>("Bias");
-    auto output = ctx.Output<framework::LoDTensor>("Out");
-    int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
-    auto w_dims = w->dims();
-
-    std::vector<int64_t> output_dims;
-    FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims);
-    output->Resize(framework::make_ddim(output_dims));
-    output->set_lod(input->lod());
-
-    auto out_dims = output->dims();
-    int M = framework::product(out_dims) / w_dims[1];
-
-    const T* input_data = input->data<T>();
-    const T* w_data = w->data<T>();
-    T* output_data = output->mutable_data<T>(ctx.GetPlace());
-    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
-    math::FCCompute<platform::CPUDeviceContext, T>(
-        blas, M, w_dims[1], w_dims[0], input_data, w_data, output_data,
-        bias ? bias->data<T>() : NULL);
-
-    // TODO(TJ): fuse act
+  void Make() override {  // declares fc's inputs, output and attributes
+    AddInput("Input",
+             "(Tensor), The input tensor of fully connected operator.");
+    AddInput("W", "(Tensor), The weight fc op with shape (I, O).");
+    AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O)")
+        .AsDispensable();
+    AddOutput("Out",
+              "(Tensor) The output tensor of fully connected operator. ");
+    AddAttr<int>("in_num_col_dims",
+                 "(int, default 1), The fc op can take tensors with more than "
+                 "two dimensions as its inputs.")
+        .SetDefault(1)
+        .EqualGreaterThan(1);
+    AddAttr<std::string>("activation_type",
+                         "Activation type used in fully connected operator.")
+        .SetDefault("");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
+    AddAttr<bool>(framework::kAllKernelsMustComputeRuntimeShape,
+                  "Skip calling InferShape() function in the runtime.")
+        .SetDefault(true);
+    AddComment(R"DOC(
+Fully Connected Operator.
+
+The fully connected operation calculates the output based on the input, weights and bias.
+The size of each dimension of the parameters checked in the infer-shape.
+)DOC");
   }
 };
 
@@ -170,4 +156,6 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(fc, ops::FCOp, ops::FCOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(fc_grad, ops::FCOpGrad);
-REGISTER_OP_CPU_KERNEL(fc, ops::FCOpKernel<float>, ops::FCOpKernel<double>);
+REGISTER_OP_CPU_KERNEL(
+    fc, ops::FCOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::FCOpKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/fc_op.cu.cc b/paddle/fluid/operators/fc_op.cu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..2fd33aeb1283ec7888e83dd0f3b94af24726e741
--- /dev/null
+++ b/paddle/fluid/operators/fc_op.cu.cc
@@ -0,0 +1,20 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fc_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    fc, ops::FCOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::FCOpKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/fc_op.h b/paddle/fluid/operators/fc_op.h
index b82a63cd830b569c4541bbaffb5affb75394773a..bf08e6ba6866e3929fdbe58619507ddccb7162ad 100644
--- a/paddle/fluid/operators/fc_op.h
+++ b/paddle/fluid/operators/fc_op.h
@@ -14,24 +14,16 @@ limitations under the License. */
 
 #pragma once
 
+#include <string>
+#include <vector>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/fc.h"
 
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
 
-class FCOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override;
-
- protected:
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override;
-};
-
 class FCOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -43,11 +35,6 @@ class FCOpGrad : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override;
 };
 
-class FCOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override;
-};
-
 inline void FCOutputSize(const framework::DDim& in_dims,
                          const framework::DDim& w_dims,
                          std::vector<int64_t>& out_dims,  // NOLINT
@@ -64,5 +51,38 @@ inline void FCOutputSize(const framework::DDim& in_dims,
   out_dims.push_back(w_dims[1]);
 }
 
+// Forward kernel of the fc op: Out = Input * W (+ Bias), with optional ReLU.
+template <typename DeviceContext, typename T>
+class FCOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<framework::LoDTensor>("Input");
+    auto* w = ctx.Input<Tensor>("W");
+    auto* bias = ctx.Input<Tensor>("Bias");  // may be null; checked below
+    auto* output = ctx.Output<framework::LoDTensor>("Out");
+    int in_num_col_dims = ctx.Attr<int>("in_num_col_dims");
+    const bool with_relu = ctx.Attr<std::string>("activation_type") == "relu";
+
+    auto w_dims = w->dims();
+    // Out's shape and LoD are (re)computed here from the actual inputs.
+    std::vector<int64_t> output_dims;
+    FCOutputSize(input->dims(), w_dims, output_dims, in_num_col_dims);
+    output->Resize(framework::make_ddim(output_dims));
+    output->set_lod(input->lod());
+
+    // M: element count of Out divided by its last dimension (w_dims[1]).
+    int M = framework::product(output->dims()) / w_dims[1];
+
+    const T* input_data = input->data<T>();
+    const T* w_data = w->data<T>();
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    math::FCFunctor<DeviceContext, T> fc;
+    fc(dev_ctx, M, w_dims[1], w_dims[0], input_data, w_data, output_data,
+       bias ? bias->data<T>() : nullptr, with_relu);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
index 2a2c583043a26ea69745253f099eb24ccc85bb58..4c13d39406be3bb5ed6b6103032b7fe811078ca1 100644
--- a/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fused_embedding_fc_lstm_op.cc
@@ -16,7 +16,6 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
-#include "paddle/fluid/operators/math/fc_compute.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc
index e67c073b5be5e2e6d8fe20a45f91e8f623dc5d02..5c89509907375b5f2089224c21dd1ef67872c2fd 100644
--- a/paddle/fluid/operators/fused/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fused/fusion_gru_op.cc
@@ -17,7 +17,7 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/fc_compute.h"
+#include "paddle/fluid/operators/math/fc.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
 
 namespace paddle {
@@ -219,9 +219,11 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     const T* wh_state_data = wh_data + D * D2;
     T* hidden_out_data = hidden_out->mutable_data<T>(place);
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
-    math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
-                                      xx_data,
-                                      bias ? bias->data<T>() : nullptr);
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    math::FCFunctor<DeviceContext, T> fc;
+    fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
+       bias ? bias->data<T>() : nullptr);
 
     int xx_offset = D3;
     int gate_offset = D;
@@ -290,17 +292,17 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
     math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
+
+    math::FCFunctor<DeviceContext, T> fc;
     if (M > D3) {
-      math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
-                                        xx_data,
-                                        bias ? bias->data<T>() : nullptr);
+      fc(dev_ctx, total_T, D3, M, x_data, wx_data, xx_data,
+         bias ? bias->data<T>() : nullptr);
       to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
     } else {
       to_batch(dev_ctx, *x, xx, true, is_reverse);
       batched_input->set_lod(xx->lod());
-      math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, xx_data, wx_data,
-                                        batched_input_data,
-                                        bias ? bias->data<T>() : nullptr);
+      fc(dev_ctx, total_T, D3, M, xx_data, wx_data, batched_input_data,
+         bias ? bias->data<T>() : nullptr);
     }
 
     auto batched_lod = batched_input->lod();
diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc
index f04aa017e3fc7527054c1bb90f8427638ccc9582..32f0e37a64b98d7e184bd6522504b6821a548af4 100644
--- a/paddle/fluid/operators/fused/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc
@@ -16,7 +16,7 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/fc_compute.h"
+#include "paddle/fluid/operators/math/fc.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
 
 namespace paddle {
@@ -281,8 +281,10 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     T* h_out_data = hidden_out->mutable_data<T>(place);
     T* c_out_data = cell_out->mutable_data<T>(place);
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
-    math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
-                                      xx_data, bias->data<T>());
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    math::FCFunctor<DeviceContext, T> fc;
+    fc(dev_ctx, total_T, D4, M, x_data, wx_data, xx_data, bias->data<T>());
 
     int xx_offset = D4;
     int gate_offset = D;
@@ -359,16 +361,15 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
+    math::FCFunctor<DeviceContext, T> fc;
     if (M > D4) {
-      math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, x_data, wx_data,
-                                        xx_data, bias->data<T>());
+      fc(dev_ctx, x_dims[0], D4, M, x_data, wx_data, xx_data, bias->data<T>());
       to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
     } else {
       to_batch(dev_ctx, *x, xx, true, is_reverse);
       batched_input->set_lod(xx->lod());
-      math::FCCompute<DeviceContext, T>(blas, x_dims[0], D4, M, xx_data,
-                                        wx_data, batched_input_data,
-                                        bias->data<T>());
+      fc(dev_ctx, x_dims[0], D4, M, xx_data, wx_data, batched_input_data,
+         bias->data<T>());
     }
 
     auto batched_lod = batched_input->lod();
diff --git a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc
index 4a45177201af27709165bfc8bc881151575337b1..519670cc6a7b73b679645e5ee6d98b74613cdacc 100644
--- a/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc
+++ b/paddle/fluid/operators/fused/fusion_seqconv_eltadd_relu_op.cc
@@ -16,7 +16,7 @@ limitations under the License. */
 #include <algorithm>  // for min, max
 #include <string>
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/fc_compute.h"
+#include "paddle/fluid/operators/math/fc.h"
 
 namespace paddle {
 namespace operators {
@@ -209,9 +209,9 @@ class FusionSeqConvEltAddReluKernel : public framework::OpKernel<T> {
       }
     }
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
-    math::FCCompute<DeviceContext, T>(blas, x_dims[0], w_dims[1], w_dims[0],
-                                      col_data, w_data, y_data, b_data, true);
+    math::FCFunctor<DeviceContext, T> fc;
+    fc(dev_ctx, x_dims[0], w_dims[1], w_dims[0], col_data, w_data, y_data,
+       b_data, true);
   }
 };
 
diff --git a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
index 46632c1e9a4999a6e417e850874354f6f8817ba0..95a08d3b0f030e7dae6668a788b52cfe66daa250 100644
--- a/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
+++ b/paddle/fluid/operators/fused/fusion_seqexpand_concat_fc_op.cc
@@ -16,7 +16,7 @@ limitations under the License. */
 #include <string>
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
-#include "paddle/fluid/operators/math/fc_compute.h"
+#include "paddle/fluid/operators/math/fc.h"
 #include "paddle/fluid/platform/cpu_info.h"
 
 namespace paddle {
@@ -165,8 +165,11 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
     T* fc_out_data = fc_out->mutable_data<T>(ctx.GetPlace());
 
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
-    math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data,
-                                      out_data, b ? b->data<T>() : NULL);
+
+    auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
+    math::FCFunctor<DeviceContext, T> fc;
+    fc(dev_ctx, total_T, D, M0, ref_in_data, w_data, out_data,
+       b ? b->data<T>() : NULL);
     w_data = w_data + M0 * D;
     // first write on
     blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index 88e750214ef31216a7dd35d4859fa2266b47bf86..ca0c92b4fbee1f0275dd2a02a5b7bbfcd496b9ae 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -56,6 +56,7 @@ math_library(sequence_pooling DEPS math_function jit_kernel_helper)
 math_library(sequence_scale)
 math_library(softmax DEPS math_function jit_kernel_helper)
 math_library(beam_search DEPS math_function)
+math_library(fc DEPS blas jit_kernel_helper)  # fc.cc uses jit::KernelFuncs
 
 math_library(matrix_bit_code)
 
diff --git a/paddle/fluid/operators/math/fc.cc b/paddle/fluid/operators/math/fc.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b5479a1b435682384e555c6607a097c9e0c82bd8
--- /dev/null
+++ b/paddle/fluid/operators/math/fc.cc
@@ -0,0 +1,79 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/math/fc.h"
+#include "paddle/fluid/operators/jit/kernels.h"
+#include "paddle/fluid/operators/math/blas.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+// CPU specialization of the fully-connected functor.
+//
+// Computes Y = X * W through BLAS, where Y holds M rows of N values, then
+// optionally adds the bias B (N values, shared by every row) and applies
+// ReLU in place. B may be nullptr, meaning "no bias".
+template <typename T>
+class FCFunctor<platform::CPUDeviceContext, T> {
+ public:
+  void operator()(const platform::CPUDeviceContext& context, const int M,
+                  const int N, const int K, const T* X, const T* W, T* Y,
+                  const T* B = nullptr, bool relu = false) {
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+    blas.MatMul(M, N, K, X, W, Y);
+    if (B == nullptr) {
+      if (relu) {
+        // No bias to fuse with, but the activation still has to run;
+        // returning early here would silently drop it.
+        for (int i = 0; i < M; i++) {
+          T* dst = Y + i * N;
+          for (int j = 0; j < N; j++) {
+            dst[j] = (dst[j] > 0) ? dst[j] : 0;
+          }
+        }
+      }
+      return;
+    }
+    if (relu) {
+      // Fused bias-add + ReLU, one output row at a time.
+      auto compute =
+          jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache()
+              .At(N);
+      for (int i = 0; i < M; i++) {
+        T* dst = Y + i * N;
+        compute(B, dst, dst, N);
+      }
+    } else {
+      // Bias-add only; the row loop is OpenMP-parallel under PADDLE_WITH_MKLML.
+      auto compute =
+          jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(
+              N);
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+      for (int i = 0; i < M; i++) {
+        T* dst = Y + i * N;
+        compute(B, dst, dst, N);
+      }
+    }
+  }
+};
+
+template class FCFunctor<platform::CPUDeviceContext, float>;
+template class FCFunctor<platform::CPUDeviceContext, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/fc.cu b/paddle/fluid/operators/math/fc.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1b22b81039954bfcf8ea0f6819d778d3fa126cab
--- /dev/null
+++ b/paddle/fluid/operators/math/fc.cu
@@ -0,0 +1,73 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/fc.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+// In-place epilogue over an M x N row-major buffer:
+//   data[i * N + j] += bias[j], with the sum replaced by 0 when DoRelu is
+//   true and the sum is not greater than 0.
+// Blocks step through rows by gridDim.x; threads step through the columns
+// of a row by blockDim.x.
+template <typename T, bool DoRelu>
+__global__ void InplaceAddReluKernel(const T* bias, T* data, int M, int N) {
+  for (int row = blockIdx.x; row < M; row += gridDim.x) {
+    T* row_data = data + row * N;
+    for (int col = threadIdx.x; col < N; col += blockDim.x) {
+      const T sum = row_data[col] + bias[col];
+      row_data[col] = (DoRelu && !(sum > 0)) ? static_cast<T>(0) : sum;
+    }
+  }
+}
+
+// CUDA FCFunctor: BLAS GEMM of X and W into Y, then, when a bias B is given,
+// InplaceAddReluKernel updates Y in place on the context's stream.
+template <typename T>
+class FCFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context, const int M,
+                  const int N, const int K, const T* X, const T* W, T* Y,
+                  const T* B = nullptr, bool relu = false) {
+    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
+    blas.GEMM(false, false, M, N, K, static_cast<T>(1.0), X, K, W, N,
+              static_cast<T>(0.0), Y, N);
+    // Nothing to add for a missing bias or an empty Y. N == 0 must not reach
+    // the sizing below: num_threads would be 0 (division by zero, bad launch).
+    if (B == nullptr || M <= 0 || N <= 0) return;
+    // Block size: N rounded up to a multiple of 32, capped at 1024 threads.
+    const int kThreadsPerBlock = 1024;
+    const int num_threads = std::min(kThreadsPerBlock, ((N + 31) >> 5) << 5);
+    const int num_blocks =
+        std::max(context.GetMaxPhysicalThreadCount() / num_threads, 1);
+    if (relu) {
+      InplaceAddReluKernel<T, true>
+          <<<num_blocks, num_threads, 0, context.stream()>>>(B, Y, M, N);
+    } else {
+      InplaceAddReluKernel<T, false>
+          <<<num_blocks, num_threads, 0, context.stream()>>>(B, Y, M, N);
+    }
+  }
+};
+
+template class FCFunctor<platform::CUDADeviceContext, float>;
+template class FCFunctor<platform::CUDADeviceContext, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/fc.h b/paddle/fluid/operators/math/fc.h
new file mode 100644
index 0000000000000000000000000000000000000000..9bef496fb9d3977b286338a79f641fde514d8303
--- /dev/null
+++ b/paddle/fluid/operators/math/fc.h
@@ -0,0 +1,34 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename DeviceContext, typename T>
+class FCFunctor {  // operator() is only declared here, not defined
+ public:
+  void operator()(const DeviceContext& context, const int M, const int N,
+                  const int K, const T* X, const T* W, T* Y,
+                  const T* B = nullptr, bool relu = false);  // B may be null
+};
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/fc_compute.h b/paddle/fluid/operators/math/fc_compute.h
deleted file mode 100644
index 66ce57594a14d8c94737b5dbe83af413628ef1cf..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/math/fc_compute.h
+++ /dev/null
@@ -1,55 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/fluid/operators/jit/kernels.h"
-#include "paddle/fluid/operators/math/blas.h"
-
-namespace paddle {
-namespace operators {
-namespace math {
-
-template <typename DeviceContext, typename T>
-inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
-                      const int N, const int K, const T* X, const T* W, T* Y,
-                      const T* B = NULL, bool relu = false) {
-  blas.MatMul(M, N, K, X, W, Y);
-  if (B == NULL) {
-    return;
-  }
-  if (relu) {
-    auto compute =
-        jit::KernelFuncs<jit::VAddReluTuple<T>, platform::CPUPlace>::Cache().At(
-            N);
-    for (int i = 0; i < M; i++) {
-      T* dst = Y + i * N;
-      compute(B, dst, dst, N);
-    }
-  } else {
-    auto compute =
-        jit::KernelFuncs<jit::VAddTuple<T>, platform::CPUPlace>::Cache().At(N);
-#ifdef PADDLE_WITH_MKLML
-#pragma omp parallel for
-#endif
-    for (int i = 0; i < M; i++) {
-      T* dst = Y + i * N;
-      compute(B, dst, dst, N);
-    }
-  }
-}
-
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_fc_op.py b/python/paddle/fluid/tests/unittests/test_fc_op.py
index ff417ad2f16b83cd42a0603375c14450195e7fc0..6c2088af3dde213274ee068e0931df1fc699b815 100644
--- a/python/paddle/fluid/tests/unittests/test_fc_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fc_op.py
@@ -17,7 +17,7 @@ import numpy as np
 from op_test import OpTest
 
 
-def fc_refer(matrix, with_bias):
+def fc_refer(matrix, with_bias, with_relu=False):
     in_n, in_c, in_h, in_w = matrix.input.shape
     w_i, w_o = matrix.weights.shape
 
@@ -31,22 +31,32 @@ def fc_refer(matrix, with_bias):
     else:
         result = np.dot(x_data, w_data)
 
-    return result
+    if with_relu:
+        return np.maximum(result, 0)
+    else:
+        return result
 
 
 class MatrixGenerate:
-    def __init__(self, mb, ic, oc, h, w):
+    def __init__(self, mb, ic, oc, h, w, bias_dims=2):
         self.input = np.random.random((mb, ic, h, w)).astype("float32")
         self.weights = np.random.random((ic * h * w, oc)).astype("float32")
-        self.bias = np.random.random((1, oc)).astype("float32")
+        # bias_dims == 2 yields a (1, oc) row; any other value yields a flat
+        # (oc,) vector.
+        bias_shape = (1, oc) if bias_dims == 2 else (oc, )
+        self.bias = np.random.random(bias_shape).astype("float32")
 
 
 class TestFCOp(OpTest):
+    def config(self):  # hook called by setUp(); subclasses override it
+        self.with_bias = True
+        self.with_relu = True  # setUp() maps this to activation_type "relu"
+        self.matrix = MatrixGenerate(1, 10, 15, 3, 3, 2)  # bias shape (1, 15)
+
     def setUp(self):
         self.op_type = "fc"
-        self.matrix = MatrixGenerate(1, 10, 15, 3, 3)
+        self.config()
 
-        self.with_bias = True
         if self.with_bias:
             self.inputs = {
                 'Input': self.matrix.input,
@@ -56,54 +66,60 @@ class TestFCOp(OpTest):
         else:
             self.inputs = {'Input': self.matrix.input, 'W': self.matrix.weights}
 
-        self.attrs = {'use_mkldnn': False}
+        if self.with_relu:
+            activation_type = "relu"
+        else:
+            activation_type = ""
+        self.attrs = {'use_mkldnn': False, 'activation_type': activation_type}
 
-        self.outputs = {'Out': fc_refer(self.matrix, self.with_bias)}
+        self.outputs = {
+            'Out': fc_refer(self.matrix, self.with_bias, self.with_relu)
+        }
 
     def test_check_output(self):
         self.check_output()
 
 
-class TestFCOpNoBias(TestFCOp):
-    def init_shapes(self, mb, ic, oc, h, w):
+class TestFCOpNoBias1(TestFCOp):
+    def config(self):  # only Input and W are fed; W is (8, 10)
         self.with_bias = False
-        self.matrix = MatrixGenerate(mb, ic, oc, h, w)
-
-
-class TestFCOpWithBias(TestFCOp):
-    def init_shapes(self, mb, ic, oc, h, w):
-        self.with_bias = True
-        self.matrix = MatrixGenerate(mb, ic, oc, h, w)
-
+        self.with_relu = False  # activation_type stays ""
+        self.matrix = MatrixGenerate(2, 8, 10, 1, 1, 2)
 
-class TestFCOp1(TestFCOpNoBias):
-    def init_op_type(self):
-        self.init_shapes(2, 8, 10, 1, 1)
 
-
-class TestFCOp2(TestFCOpNoBias):
-    def init_op_type(self):
-        self.init_shapes(4, 5, 6, 2, 2)
+class TestFCOpNoBias2(TestFCOp):
+    def config(self):  # only Input and W are fed; W is (20, 6)
+        self.with_bias = False
+        self.with_relu = False  # activation_type stays ""
+        self.matrix = MatrixGenerate(4, 5, 6, 2, 2, 1)
 
 
-class TestFCOp4(TestFCOpNoBias):
-    def init_op_type(self):
-        self.init_shapes(1, 32, 64, 3, 3)
+class TestFCOpNoBias4(TestFCOp):
+    def config(self):  # only Input and W are fed; W is (288, 64)
+        self.with_bias = False
+        self.with_relu = False  # activation_type stays ""
+        self.matrix = MatrixGenerate(1, 32, 64, 3, 3, 1)
 
 
-class TestFCOpWithBias1(TestFCOpWithBias):
-    def init_op_type(self):
-        self.init_shapes(3, 8, 10, 2, 1)
+class TestFCOpWithBias1(TestFCOp):
+    def config(self):  # bias without ReLU; W is (16, 10), bias is (1, 10)
+        self.with_bias = True
+        self.with_relu = False  # activation_type stays ""
+        self.matrix = MatrixGenerate(3, 8, 10, 2, 1, 2)
 
 
-class TestFCOpWithBias2(TestFCOpWithBias):
-    def init_op_type(self):
-        self.init_shapes(4, 5, 6, 2, 2)
+class TestFCOpWithBias2(TestFCOp):
+    def config(self):  # bias + ReLU; W is (20, 6), bias is flat (6,)
+        self.with_bias = True
+        self.with_relu = True  # setUp() maps this to activation_type "relu"
+        self.matrix = MatrixGenerate(4, 5, 6, 2, 2, 1)
 
 
-class TestFCOpWithBias3(TestFCOpWithBias):
-    def init_op_type(self):
-        self.init_shapes(1, 64, 32, 3, 3)
+class TestFCOpWithBias3(TestFCOp):
+    def config(self):  # bias + ReLU; W is (576, 32), bias is flat (32,)
+        self.with_bias = True
+        self.with_relu = True  # setUp() maps this to activation_type "relu"
+        self.matrix = MatrixGenerate(1, 64, 32, 3, 3, 1)
 
 
 if __name__ == "__main__":