diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index f8b0bce6815ff17a60ef64b0eec34a7cc9d16e72..e56895c63a426b782f7b46091bc86c367d49899d 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -88,10 +88,14 @@ add_subdirectory(math)
 
 set(DEPS_OPS
     recurrent_op
-    cond_op)
+    cond_op
+    cross_entropy_op
+    softmax_with_cross_entropy_op)
 op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
   DEPS framework_proto tensor net_op)
 op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
+op_library(cross_entropy_op DEPS cross_entropy_function)
+op_library(softmax_with_cross_entropy_op DEPS cross_entropy_function softmax_function)
 
 list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
 foreach(src ${GENERAL_OPS})
diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu
index 18e44d77c9f62b296dc57952e546f844670c7d57..1cfeb7a53b047541322ac53c5b7249e660039d5c 100644
--- a/paddle/operators/cross_entropy_op.cu
+++ b/paddle/operators/cross_entropy_op.cu
@@ -12,62 +12,12 @@
    See the License for the specific language governing permissions and
    limitations under the License. */
 
-#include "paddle/framework/op_registry.h"
 #include "paddle/operators/cross_entropy_op.h"
-#include "paddle/platform/assert.h"
-#include "paddle/platform/hostdevice.h"
 
 namespace paddle {
 namespace operators {
 
-template <typename T>
-__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
-                                   const int N, const int D) {
-  // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
-  // CUDA_1D_KERNEL_LOOP(i, N) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
-    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
-    Y[i] = -TolerableValue<T>()(log(X[i * D + label[i]]));
-  }
-}
-
-template <typename T>
-__device__ __forceinline__ T sum_single_warp(T val) {
-  val += __shfl_down(val, 16);
-  val += __shfl_down(val, 8);
-  val += __shfl_down(val, 4);
-  val += __shfl_down(val, 2);
-  val += __shfl_down(val, 1);
-  return val;
-}
-
-template <typename T>
-__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
-                                       const int class_num) {
-  int tid = threadIdx.x;
-  extern __shared__ T d_sum[];
-  d_sum[tid] = 0;
-
-  int cur_idx = tid;
-  int next_idx = blockIdx.x * class_num + tid;
-  while (cur_idx < class_num) {
-    d_sum[tid] += TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
-    next_idx += blockDim.x;
-    cur_idx += blockDim.x;
-  }
-  __syncthreads();
-
-  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
-    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
-    __syncthreads();
-  }
-
-  T val = d_sum[tid];
-  val = sum_single_warp<T>(val);
-  if (tid == 0) Y[blockIdx.x] = -val;
-}
-
+namespace {
 // TODO(qingqing): make zero setting a common function.
 template <typename T>
 __global__ void Zero(T* X, const int N) {
@@ -100,6 +50,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
     dX[ids] = -label[ids] * dY[row_ids] / X[ids];
   }
 }
+}  // namespace
 
 template <typename T>
 class CrossEntropyOpCUDAKernel : public framework::OpKernel {
@@ -107,36 +58,13 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "This kernel only runs on GPU device.");
-
     const Tensor* x = ctx.Input<Tensor>("X");
     const Tensor* label = ctx.Input<Tensor>("Label");
     Tensor* y = ctx.Output<Tensor>("Y");
+    y->mutable_data<T>(ctx.GetPlace());
 
-    const T* x_data = x->data<T>();
-    T* y_data = y->mutable_data<T>(ctx.GetPlace());
-
-    int batch_size = x->dims()[0];
-    int class_num = x->dims()[1];
-
-    if (ctx.Attr<bool>("softLabel")) {
-      auto* label_data = ctx.Input<Tensor>("Label")->data<T>();
-      int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
-
-      SoftCrossEntropyKernel<
-          T><<<batch_size, block, block * sizeof(T),
-               reinterpret_cast<const platform::CUDADeviceContext&>(
-                   ctx.device_context())
-                   .stream()>>>(y_data, x_data, label_data, class_num);
-    } else {
-      auto* label_data = ctx.Input<Tensor>("Label")->data<int>();
-      int block = 512;
-      int grid = (batch_size + block - 1) / block;
-      CrossEntropyKernel<T><<<
-          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
-                              ctx.device_context())
-                              .stream()>>>(y_data, x_data, label_data,
-                                           batch_size, class_num);
-    }
+    math::CrossEntropyFunctor<platform::GPUPlace, T>()(
+        ctx, y, x, label, ctx.Attr<bool>("softLabel"));
   }
 };
 
@@ -150,6 +78,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
     const Tensor* x = ctx.Input<Tensor>("X");
     const Tensor* label = ctx.Input<Tensor>("Label");
     Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    dx->mutable_data<T>(ctx.GetPlace());
 
     const T* dy_data =
         ctx.Input<Tensor>(framework::GradVarName("Y"))->data<T>();
diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h
index 255b2e9f5ea7566cca7fd3914e38da804b7c7006..1f67461d3fadb1a979832ad049d4e0098256b834 100644
--- a/paddle/operators/cross_entropy_op.h
+++ b/paddle/operators/cross_entropy_op.h
@@ -15,7 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/platform/hostdevice.h"
+#include "paddle/operators/math/cross_entropy.h"
 
 namespace paddle {
 namespace operators {
@@ -25,18 +25,6 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
-template <typename T>
-struct TolerableValue {
-  HOSTDEVICE T operator()(const T& x) const {
-    PADDLE_ASSERT(std::is_floating_point<T>::value);
-    const T kApproInf = 1e20;
-
-    if (x == INFINITY) return kApproInf;
-    if (x == -INFINITY) return -kApproInf;
-    return x;
-  }
-};
-
 template <typename T>
 class CrossEntropyOpKernel : public framework::OpKernel {
  public:
@@ -46,28 +34,10 @@ class CrossEntropyOpKernel : public framework::OpKernel {
     const Tensor* x = ctx.Input<Tensor>("X");
     const Tensor* labels = ctx.Input<Tensor>("Label");
     Tensor* y = ctx.Output<Tensor>("Y");
-    T* y_data = y->mutable_data<T>(ctx.GetPlace());
-
-    const int batch_size = x->dims()[0];
-    if (ctx.Attr<bool>("softLabel")) {
-      auto prob = EigenMatrix<T>::From(*x);
-      auto lbl_mat = EigenMatrix<T>::From(*labels);
-      auto loss = EigenMatrix<T>::From(*y);
+    y->mutable_data<T>(ctx.GetPlace());
 
-      loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
-          -((lbl_mat * prob.log().unaryExpr(TolerableValue<T>()))
-                .sum(Eigen::DSizes<int, 1>(1))
-                .reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
-    } else {
-      const int class_num = x->dims()[1];
-      const T* x_data = x->data<T>();
-
-      const int* label_data = labels->data<int>();
-      for (int i = 0; i < batch_size; ++i) {
-        int index = i * class_num + label_data[i];
-        y_data[i] = -TolerableValue<T>()(std::log(x_data[index]));
-      }
-    }
+    math::CrossEntropyFunctor<platform::CPUPlace, T>()(
+        ctx, y, x, labels, ctx.Attr<bool>("softLabel"));
   }
 };
 
diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt
index f8333f34f7b4c7b0f9a0f14a7a33f9d98e1d331c..91ae3d49f1df51d9524547f7765285bff9dbb5c5 100644
--- a/paddle/operators/math/CMakeLists.txt
+++ b/paddle/operators/math/CMakeLists.txt
@@ -1,9 +1,15 @@
-
 if(WITH_GPU)
-    nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc 
-    im2col.cu DEPS cblas device_context)
+    nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
+      im2col.cu DEPS cblas device_context operator)
+    nv_library(softmax_function SRCS softmax.cc softmax.cu
+      DEPS operator)
+    nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu
+      DEPS operator)
 else()
-    cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context)
+    cc_library(math_function SRCS math_function.cc im2col.cc
+      DEPS cblas device_context operator)
+    cc_library(softmax_function SRCS softmax.cc DEPS operator)
+    cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator)
 endif()
 
 nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
diff --git a/paddle/operators/math/cross_entropy.cc b/paddle/operators/math/cross_entropy.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a5a426bc7b16852e67afd790df7a91d89a458c8a
--- /dev/null
+++ b/paddle/operators/math/cross_entropy.cc
@@ -0,0 +1,59 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/math/cross_entropy.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename T>
+class CrossEntropyFunctor<platform::CPUPlace, T> {
+ public:
+  void operator()(const framework::ExecutionContext& ctx,
+                  framework::Tensor* out, const framework::Tensor* prob,
+                  const framework::Tensor* labels, const bool softLabel) {
+    const int batch_size = prob->dims()[0];
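+    // Soft labels: loss_i = -sum_j label_ij * log(prob_ij); hard labels:
+    // loss_i = -log(prob_i[label_i]). TolerableValue replaces the +/-inf that
+    // log() can produce for zero probabilities with a large finite value.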
+    if (softLabel) {
+      auto in = EigenMatrix<T>::From(*prob);
+      auto lbl = EigenMatrix<T>::From(*labels);
+      auto loss = EigenMatrix<T>::From(*out);
+
+      loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
+          -((lbl * in.log().unaryExpr(math::TolerableValue<T>()))
+                .sum(Eigen::DSizes<int, 1>(1))
+                .reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
+    } else {
+      const int class_num = prob->dims()[1];
+      const T* prob_data = prob->data<T>();
+      T* loss_data = out->data<T>();
+
+      const int* label_data = labels->data<int>();
+      for (int i = 0; i < batch_size; ++i) {
+        int index = i * class_num + label_data[i];
+        loss_data[i] = -math::TolerableValue<T>()(std::log(prob_data[index]));
+      }
+    }
+  }
+};
+
+template class CrossEntropyFunctor<platform::CPUPlace, float>;
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/cross_entropy.cu b/paddle/operators/math/cross_entropy.cu
new file mode 100644
index 0000000000000000000000000000000000000000..d14a75a30c01deb86937a3ced43005aed4066d86
--- /dev/null
+++ b/paddle/operators/math/cross_entropy.cu
@@ -0,0 +1,111 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/math/cross_entropy.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+namespace {
+template <typename T>
+__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
+                                   const int N, const int D) {
+  // TODO(qingqing): define CUDA_1D_KERNEL_LOOP macro in a common file.
+  // CUDA_1D_KERNEL_LOOP(i, N) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
+    Y[i] = -math::TolerableValue<T>()(log(X[i * D + label[i]]));
+  }
+}
+
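+// Reduces `val` across the 32 lanes of a warp with shuffle-down steps; after
+// the call, lane 0 holds the sum of all lanes' inputs.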
+template <typename T>
+__device__ __forceinline__ T sum_single_warp(T val) {
+  val += __shfl_down(val, 16);
+  val += __shfl_down(val, 8);
+  val += __shfl_down(val, 4);
+  val += __shfl_down(val, 2);
+  val += __shfl_down(val, 1);
+  return val;
+}
+
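+// One block per sample: each thread accumulates label * log(prob) over a
+// strided slice of the classes into dynamic shared memory, a tree reduction
+// shrinks the partial sums to a single warp, and sum_single_warp finishes the
+// reduction.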
+template <typename T>
+__global__ void SoftCrossEntropyKernel(T* Y, const T* X, const T* label,
+                                       const int class_num) {
+  int tid = threadIdx.x;
+  extern __shared__ T d_sum[];
+  d_sum[tid] = 0;
+
+  int cur_idx = tid;
+  int next_idx = blockIdx.x * class_num + tid;
+  while (cur_idx < class_num) {
+    d_sum[tid] +=
+        math::TolerableValue<T>()(std::log(X[next_idx])) * label[next_idx];
+    next_idx += blockDim.x;
+    cur_idx += blockDim.x;
+  }
+  __syncthreads();
+
+  for (unsigned int stride = blockDim.x >> 1; stride >= 32; stride >>= 1) {
+    if (tid < stride) d_sum[tid] += d_sum[tid + stride];
+    __syncthreads();
+  }
+
+  T val = d_sum[tid];
+  val = sum_single_warp<T>(val);
+  if (tid == 0) Y[blockIdx.x] = -val;
+}
+}  // namespace
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class CrossEntropyFunctor<platform::GPUPlace, T> {
+ public:
+  void operator()(const framework::ExecutionContext& ctx,
+                  framework::Tensor* out, const framework::Tensor* prob,
+                  const framework::Tensor* labels, bool softLabel) {
+    const T* prob_data = prob->data<T>();
+    T* loss_data = out->mutable_data<T>(ctx.GetPlace());
+
+    int batch_size = prob->dims()[0];
+    int class_num = prob->dims()[1];
+
+    if (softLabel) {
+      const T* label_data = labels->data<T>();
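+      // Use one block per sample; the block size is class_num rounded down to
+      // a power of two (capped at 512) so the tree reduction above works on a
+      // power-of-two thread count.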
+      int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
+
+      SoftCrossEntropyKernel<
+          T><<<batch_size, block, block * sizeof(T),
+               reinterpret_cast<const platform::CUDADeviceContext&>(
+                   ctx.device_context())
+                   .stream()>>>(loss_data, prob_data, label_data, class_num);
+    } else {
+      const int* label_data = labels->data<int>();
+      int block = 512;
+      int grid = (batch_size + block - 1) / block;
+      CrossEntropyKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              ctx.device_context())
+                              .stream()>>>(loss_data, prob_data, label_data,
+                                           batch_size, class_num);
+    }
+  }
+};
+
+template class CrossEntropyFunctor<platform::GPUPlace, float>;
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/cross_entropy.h b/paddle/operators/math/cross_entropy.h
new file mode 100644
index 0000000000000000000000000000000000000000..18e637cf9186b5dc21e94f1ab15b3d858ec93c67
--- /dev/null
+++ b/paddle/operators/math/cross_entropy.h
@@ -0,0 +1,48 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/operator.h"
+#include "paddle/framework/tensor.h"
+#include "paddle/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
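+// log() returns -INFINITY for a zero probability; TolerableValue caps
+// +/-INFINITY at a large finite magnitude so the loss and its gradient stay
+// finite.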
+template <typename T>
+struct TolerableValue {
+  HOSTDEVICE T operator()(const T& x) const {
+    PADDLE_ASSERT(std::is_floating_point<T>::value);
+    const T kApproInf = 1e20;
+
+    if (x == INFINITY) return kApproInf;
+    if (x == -INFINITY) return -kApproInf;
+    return x;
+  }
+};
+
+template <typename Place, typename T>
+class CrossEntropyFunctor {
+ public:
+  // TODO(caoying): it is much better to use DeviceContext as the first
+  // parameter.
+  void operator()(const framework::ExecutionContext& context,
+                  framework::Tensor* out, const framework::Tensor* prob,
+                  const framework::Tensor* labels, const bool softLabel);
+};
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/softmax.cc b/paddle/operators/math/softmax.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1224c05810543226773b02b5aa1b26cea15d9f54
--- /dev/null
+++ b/paddle/operators/math/softmax.cc
@@ -0,0 +1,25 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/math/softmax.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template class SoftmaxFunctor<platform::CPUPlace, float>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/softmax.cu b/paddle/operators/math/softmax.cu
new file mode 100644
index 0000000000000000000000000000000000000000..4c3df0550e7ca6f4310db1d35cc34d5c73a2dd16
--- /dev/null
+++ b/paddle/operators/math/softmax.cu
@@ -0,0 +1,27 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+
+#include "paddle/operators/math/softmax.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template class SoftmaxFunctor<platform::GPUPlace, float>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/math/softmax.h b/paddle/operators/math/softmax.h
new file mode 100644
index 0000000000000000000000000000000000000000..3d2f0d0aecffcd0fe51166c3d863aa8b91bba196
--- /dev/null
+++ b/paddle/operators/math/softmax.h
@@ -0,0 +1,73 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/operator.h"
+#include "paddle/framework/tensor.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
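+// Clips the shifted logits at -64 so that exp() never underflows to zero and
+// the subsequent normalization (and any log of the softmax output) stays well
+// defined.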
+template <typename T>
+struct ValueClip {
+  HOSTDEVICE T operator()(const T& x) const {
+    const T kThreshold = -64.;
+    return x < kThreshold ? kThreshold : x;
+  }
+};
+
+template <typename Place, typename T>
+class SoftmaxFunctor {
+ public:
+  void operator()(const framework::ExecutionContext& context,
+                  const framework::Tensor* X, framework::Tensor* Y) {
+    auto logits = EigenMatrix<T>::From(*X);
+    auto softmax = EigenMatrix<T>::From(*Y);
+
+    const int kBatchDim = 0;
+    const int kClassDim = 1;
+
+    const int batch_size = logits.dimension(kBatchDim);
+    const int num_classes = logits.dimension(kClassDim);
+
+    Eigen::DSizes<int, 1> along_class(kClassDim);
+    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
+    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
+
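+    // Shift each row by its maximum (broadcast back over the class dimension)
+    // for numerical stability, clip, exponentiate, then normalize each row by
+    // its sum.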
+    auto shifted_logits = (logits -
+                           logits.maximum(along_class)
+                               .eval()
+                               .reshape(batch_by_one)
+                               .broadcast(one_by_class))
+                              .unaryExpr(ValueClip<T>());
+
+    softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
+    softmax.device(context.GetEigenDevice<Place>()) =
+        (softmax *
+         softmax.sum(along_class)
+             .inverse()
+             .eval()
+             .reshape(batch_by_one)
+             .broadcast(one_by_class));
+  }
+};
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h
index 8a3a5ab927c0e2937936fcc973f000d4d95c3dbc..7220f486be055e1b841a06b15f519717c54f575c 100644
--- a/paddle/operators/softmax_op.h
+++ b/paddle/operators/softmax_op.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/softmax.h"
 
 namespace paddle {
 namespace operators {
@@ -30,36 +31,11 @@ class SoftmaxKernel : public framework::OpKernel {
   void Compute(const framework::ExecutionContext& context) const override {
     auto X = context.Input<Tensor>("X");
     auto Y = context.Output<Tensor>("Y");
-    Y->mutable_data<T>(context.GetPlace());
-
-    auto logits = EigenMatrix<T>::From(*X);
-    auto softmax = EigenMatrix<T>::From(*Y);
-
-    const int kBatchDim = 0;
-    const int kClassDim = 1;
 
-    const int batch_size = logits.dimension(kBatchDim);
-    const int num_classes = logits.dimension(kClassDim);
-
-    Eigen::DSizes<int, 1> along_class(kClassDim);
-    Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
-    Eigen::DSizes<int, 2> one_by_class(1, num_classes);
-
-    auto shifted_logits = (logits -
-                           logits.maximum(along_class)
-                               .eval()
-                               .reshape(batch_by_one)
-                               .broadcast(one_by_class));
-
-    softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
+    // allocate memory on device.
+    Y->mutable_data<T>(context.GetPlace());
 
-    softmax.device(context.GetEigenDevice<Place>()) =
-        (softmax *
-         softmax.sum(along_class)
-             .inverse()
-             .eval()
-             .reshape(batch_by_one)
-             .broadcast(one_by_class));
+    math::SoftmaxFunctor<Place, T>()(context, X, Y);
   }
 };
 
@@ -67,8 +43,6 @@ template <typename Place, typename T>
 class SoftmaxGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    std::shared_ptr<Tensor> scale_ = std::make_shared<Tensor>();
-
     auto Y = context.Input<Tensor>("Y");
     auto dY = context.Input<Tensor>(framework::GradVarName("Y"));
     auto dX = context.Output<Tensor>(framework::GradVarName("X"));
diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b6f33ad9e030bc69184de2fb5b88c9dcb07e71a9
--- /dev/null
+++ b/paddle/operators/softmax_with_cross_entropy_op.cc
@@ -0,0 +1,169 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/softmax_with_cross_entropy_op.h"
+
+namespace paddle {
+namespace operators {
+
+class SoftmaxWithCrossEntropyOpMaker
+    : public framework::OpProtoAndCheckerMaker {
+ public:
+  SoftmaxWithCrossEntropyOpMaker(framework::OpProto* proto,
+                                 framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("Logits",
+             "(Tensor, default: Tensor<float>), The unscaled log probabilities "
+             "which is a 2-D tensor with shape [N x K]. N is the batch_size, "
+             "and K is the class number.")
+        .NotInGradient();
+    AddInput(
+        "Label",
+        "(Tensor, default: Tensor<int>), The ground truth which is a 2-D "
+        "tensor. "
+        "If softLable is set to 0, Label is a Tensor<int> with shape [N x 1]. "
+        "If softLable is set to 1, Label is a Tensor<float/double> "
+        "with shape [N x K].");
+    AddOutput(
+        "Softmax",
+        "(Tensor, default: Tensor<float>), A 2-D tensor with shape [N x K]. "
+        "The outputs value of softmax activation by given the input batch, "
+        "which will be used in backward calculation.")
+        .AsIntermediate();
+    AddOutput("Loss",
+              "(Tensor, default: Tensor<float>), A 2-D tensor. The cross "
+              "entropy loss with shape [N x 1].");
+    AddAttr<bool>(
+        "softLabel",
+        "(bool, default: false), A flag to indicate whether to interpretate "
+        "the given labels as soft labels.")
+        .SetDefault(false);
+    AddComment(R"DOC(
+Cross entropy loss combined with softmax is used extensively as the output
+layer. This operator computes the softmax normalized values for each row of the
+input tensor and then computes the cross-entropy loss from them, which yields a
+more numerically stable gradient.
+
+Because this operator performs the softmax on the logits internally, it expects
+unscaled logits. Do not feed it the output of the softmax operator; doing so
+will produce incorrect results.
+
+When softLabel is false, this operator expects mutually exclusive hard labels:
+each sample in the batch belongs to exactly one class with probability 1, i.e.,
+each sample has one and only one label.
+
+Equation:
+
+1) hard label (one-hot label)
+
+Loss_j = -\text{Logit}_{Label_j} + \log\left(\sum_{i=1}^{K}\exp(\text{Logit}_i)\right), j = 1, ..., N
+
+2) soft label (a distribution over all classes)
+
+Loss_j = -\sum_{i=1}^{K}\text{Label}_i\left(\text{Logit}_i-\log\left(\sum_{k=1}^{K}\exp(\text{Logit}_k)\right)\right), j = 1, ..., N
+
+)DOC");
+  }
+};
+
+class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext& ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Logits"),
+                            "Input(Logits) should be not null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
+                            "Input(Label) should be not null.");
+
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Softmax"),
+                            "Output(Softmax) should be not null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Loss"),
+                            "Output(Loss) should be not null.");
+
+    const Tensor* logits = ctx.Input<Tensor>("Logits");
+    const Tensor* labels = ctx.Input<Tensor>("Label");
+    PADDLE_ENFORCE_EQ(
+        logits->dims().size(), 2UL,
+        "The input of softmax_with_cross_entropy should be a 2-D tensor.");
+    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Label")->dims().size(), 2UL,
+                      "The labels should be a 2-D tensor.");
+
+    if (ctx.Attr<bool>("softLabel")) {
+      PADDLE_ENFORCE_EQ(logits->dims()[1], labels->dims()[1],
+                        "If Attr(softLabel) == true, the 2nd dimension of "
+                        "Input(X) and Input(Label) should be equal.");
+    } else {
+      PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL,
+                        "If Attr(softLabel) == false, the 2nd dimension of "
+                        "Input(Label) should be 1.");
+    }
+
+    ctx.Output<framework::Tensor>("Softmax")->Resize(logits->dims());
+    ctx.Output<framework::Tensor>("Loss")->Resize({logits->dims()[0], 1});
+
+    ctx.ShareLoD("Logits", /*->*/ "Softmax");
+    ctx.ShareLoD("Logits", /*->*/ "Loss");
+  }
+};
+
+class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext& ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Loss")),
+                            "Input(Loss@Grad) should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Softmax"),
+                            "Input(Softmax) should be not null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
+                            "Input(Label) should be not null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar(framework::GradVarName("Logits")),
+                            "Output(Logits@Grad) should be not null.");
+
+    const Tensor* softmax = ctx.Input<Tensor>("Softmax");
+    const Tensor* labels = ctx.Input<Tensor>("Label");
+    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Label")->dims().size(), 2UL,
+                      "The labels should be a 2-D tensor.");
+
+    if (ctx.Attr<bool>("softLabel")) {
+      PADDLE_ENFORCE_EQ(softmax->dims()[1], labels->dims()[1],
+                        "When Attr(softLabel) == true, the 2nd dimension of "
+                        "Input(X) and Input(Label) should be equal.");
+    } else {
+      PADDLE_ENFORCE_EQ(labels->dims()[1], 1UL,
+                        "When Attr(softLabel) == false, the 2nd dimension of "
+                        "Input(Label) should be 1.");
+    }
+
+    ctx.Output<framework::LoDTensor>(framework::GradVarName("Logits"))
+        ->Resize(ctx.Input<Tensor>("Softmax")->dims());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP(softmax_with_cross_entropy, ops::SoftmaxWithCrossEntropyOp,
+            ops::SoftmaxWithCrossEntropyOpMaker,
+            softmax_with_cross_entropy_grad,
+            ops::SoftmaxWithCrossEntropyOpGrad);
+REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy,
+                       ops::SoftmaxWithCrossEntropyKernel<float>);
+REGISTER_OP_CPU_KERNEL(softmax_with_cross_entropy_grad,
+                       ops::SoftmaxWithCrossEntropyGradKernel<float>);
diff --git a/paddle/operators/softmax_with_cross_entropy_op.cu b/paddle/operators/softmax_with_cross_entropy_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1cf4296dccf68aece6fdfb7910a9c68449633b76
--- /dev/null
+++ b/paddle/operators/softmax_with_cross_entropy_op.cu
@@ -0,0 +1,119 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+
+#include "paddle/operators/softmax_with_cross_entropy_op.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+namespace {
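+// Hard-label gradient kernel: logit_grad initially holds the forward softmax
+// values; every element is scaled by its sample's incoming loss gradient, then
+// 1 is subtracted at the ground-truth class of each sample.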
+template <typename T>
+__global__ void CrossEntropyGrad(T* out_grad, const T* in_grad,
+                                 const int* labels, const int batch_size,
+                                 const int class_num) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  int sample_idx = tid / class_num;
+
+  if (tid < batch_size * class_num) out_grad[tid] *= in_grad[sample_idx];
+  __syncthreads();
+
+  if (tid < batch_size) {
+    PADDLE_ASSERT(labels[sample_idx] >= 0 && labels[sample_idx] < class_num);
+    out_grad[tid * class_num + labels[tid]] -= 1.;
+  }
+}
+
+template <typename T>
+__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
+                                               const T* loss_grad,
+                                               const T* labels,
+                                               const int batch_size,
+                                               const int class_num) {
+  int ids = blockIdx.x * blockDim.x + threadIdx.x;
+  if (ids < batch_size * class_num) {
+    int row_ids = ids / class_num;
+    logit_grad[ids] = logit_grad[ids] * loss_grad[row_ids] - labels[ids];
+  }
+}
+}  // namespace
+
+template <typename T>
+class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
+                   "This kernel only runs on GPU device.");
+    const Tensor* logits = context.Input<Tensor>("Logits");
+    const Tensor* labels = context.Input<Tensor>("Label");
+    Tensor* softmax = context.Output<Tensor>("Softmax");
+
+    Tensor* loss = context.Output<Tensor>("Loss");
+    softmax->mutable_data<T>(context.GetPlace());
+    loss->mutable_data<T>(context.GetPlace());
+
+    math::SoftmaxFunctor<platform::GPUPlace, T>()(context, logits, softmax);
+    math::CrossEntropyFunctor<platform::GPUPlace, T>()(
+        context, loss, softmax, labels, context.Attr<bool>("softLabel"));
+  }
+};
+
+template <typename T>
+class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
+                   "This kernel only runs on GPU device.");
+    const Tensor* labels = context.Input<Tensor>("Label");
+    const T* loss_grad_data =
+        context.Input<Tensor>(framework::GradVarName("Loss"))->data<T>();
+    Tensor* logit_grad =
+        context.Output<Tensor>(framework::GradVarName("Logits"));
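+    // Reuse the forward Softmax output as the gradient buffer and compute the
+    // gradient in place on it.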
+    logit_grad->ShareDataWith<T>(*context.Input<Tensor>("Softmax"));
+    T* logit_grad_data = logit_grad->data<T>();
+
+    const int batch_size = logit_grad->dims()[0];
+    const int class_num = logit_grad->dims()[1];
+    int block = 512;
+    int grid = (batch_size * class_num + block - 1) / block;
+
+    if (context.Attr<bool>("softLabel")) {
+      const T* label_data = labels->data<T>();
+      SoftCrossEntropyGradientKernel<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              context.device_context())
+                              .stream()>>>(logit_grad_data, loss_grad_data,
+                                           label_data, batch_size, class_num);
+    } else {
+      const int* label_data = labels->data<int>();
+      CrossEntropyGrad<T><<<
+          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+                              context.device_context())
+                              .stream()>>>(logit_grad_data, loss_grad_data,
+                                           label_data, batch_size, class_num);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy,
+                       ops::SoftmaxWithCrossEntropyCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(softmax_with_cross_entropy_grad,
+                       ops::SoftmaxWithCrossEntropyGradCUDAKernel<float>);
diff --git a/paddle/operators/softmax_with_cross_entropy_op.h b/paddle/operators/softmax_with_cross_entropy_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..bf792c1f59e2e43a98c93bddbc2aa63d646dee6f
--- /dev/null
+++ b/paddle/operators/softmax_with_cross_entropy_op.h
@@ -0,0 +1,86 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/cross_entropy.h"
+#include "paddle/operators/math/softmax.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename T>
+class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()),
+                   "This kernel only runs on CPU.");
+    const Tensor* logits = context.Input<Tensor>("Logits");
+    const Tensor* labels = context.Input<Tensor>("Label");
+    Tensor* softmax = context.Output<Tensor>("Softmax");
+    Tensor* loss = context.Output<Tensor>("Loss");
+
+    softmax->mutable_data<T>(context.GetPlace());
+    loss->mutable_data<T>(context.GetPlace());
+
+    math::SoftmaxFunctor<platform::CPUPlace, T>()(context, logits, softmax);
+    math::CrossEntropyFunctor<platform::CPUPlace, T>()(
+        context, loss, softmax, labels, context.Attr<bool>("softLabel"));
+  }
+};
+
+template <typename T>
+class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* out_grad =
+        context.Input<Tensor>(framework::GradVarName("Loss"));
+    const Tensor* labels = context.Input<Tensor>("Label");
+    Tensor* logit_grad =
+        context.Output<Tensor>(framework::GradVarName("Logits"));
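+    // Share the forward Softmax output's memory and overwrite it in place with
+    // the gradient w.r.t. the logits.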
+    logit_grad->ShareDataWith<T>(*context.Input<Tensor>("Softmax"));
+
+    const int class_num = logit_grad->dims()[1];
+    if (context.Attr<bool>("softLabel")) {
+      auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
+      auto logit_grad_mat = EigenMatrix<T>::From(*logit_grad);
+      auto lbl_mat = EigenMatrix<T>::From(*labels);
+
+      logit_grad_mat.device(context.GetEigenDevice<platform::CPUPlace>()) =
+          logit_grad_mat *
+              out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) -
+          lbl_mat;
+    } else {
+      const int batch_size = logit_grad->dims()[0];
+      const int* label_data = labels->data<int>();
+      const T* out_grad_data = out_grad->data<T>();
+      T* logit_grad_data = logit_grad->data<T>();
+
+      for (int i = 0; i < batch_size; ++i) {
+        int index = i * class_num + label_data[i];
+        logit_grad_data[index] =
+            (out_grad_data[i] * logit_grad_data[index] - 1.);
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py
index 0a5673868c547d9e184e8ce05346c3ebabe06892..579ad7b40738f45bf055f740e66d2238f4db22fc 100644
--- a/python/paddle/v2/framework/tests/op_test.py
+++ b/python/paddle/v2/framework/tests/op_test.py
@@ -177,7 +177,7 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place,
 
 
 class OpTest(unittest.TestCase):
-    def check_output_with_place(self, place):
+    def check_output_with_place(self, place, atol):
         self.scope = core.Scope()
         op_inputs = self.inputs if hasattr(self, "inputs") else dict()
         op_outputs = self.outputs if hasattr(self, "outputs") else dict()
@@ -206,22 +206,23 @@ class OpTest(unittest.TestCase):
                         self.scope.find_var(sub_out_name).get_tensor())
                     self.assertTrue(
                         np.allclose(
-                            actual, expect, atol=1e-05),
-                        "output name: " + out_name + " has diff")
+                            actual, expect, atol=atol),
+                        "output name: " + out_name + " has diff.")
             else:
                 actual = np.array(self.scope.find_var(out_name).get_tensor())
                 expect = self.outputs[out_name]
+
                 self.assertTrue(
                     np.allclose(
-                        actual, expect, atol=1e-05),
-                    "output name: " + out_name + " has diff")
+                        actual, expect, atol=atol),
+                    "output name: " + out_name + " has diff.")
 
-    def check_output(self):
+    def check_output(self, atol=1e-5):
         places = [core.CPUPlace()]
         if core.is_compile_gpu():
             places.append(core.GPUPlace(0))
         for place in places:
-            self.check_output_with_place(place)
+            self.check_output_with_place(place, atol)
 
     def __assert_is_close(self, numeric_grads, analytic_grads, names,
                           max_relative_error, msg_prefix):
@@ -235,9 +236,10 @@ class OpTest(unittest.TestCase):
 
             def err_msg():
                 offset = np.argmax(diff_mat > max_relative_error)
-                return "%s Variable %s max gradient diff %f over limit %f, the first " \
-                  "error element is %d" % (
-                   msg_prefix, name, max_diff, max_relative_error, offset)
+                return ("%s Variable %s max gradient diff %f over limit %f, "
+                        "the first error element is %d") % (
+                            msg_prefix, name, max_diff, max_relative_error,
+                            offset)
 
             self.assertLessEqual(max_diff, max_relative_error, err_msg())
 
diff --git a/python/paddle/v2/framework/tests/test_softmax_op.py b/python/paddle/v2/framework/tests/test_softmax_op.py
index 1b948f252fa631e9886840b377de2996e110dc91..b41c810d9a6269c934a434b085748a86deccb475 100644
--- a/python/paddle/v2/framework/tests/test_softmax_op.py
+++ b/python/paddle/v2/framework/tests/test_softmax_op.py
@@ -5,7 +5,7 @@ from op_test import OpTest
 
 def stable_softmax(x):
     """Compute the softmax of vector x in a numerically stable way."""
-    shiftx = x - np.max(x)
+    shiftx = (x - np.max(x)).clip(-64.)
     exps = np.exp(shiftx)
     return exps / np.sum(exps)
 
diff --git a/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..428395b76c8fbcbc07b19ee1979419f0e64aca85
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_softmax_with_cross_entropy_op.py
@@ -0,0 +1,70 @@
+import unittest
+import numpy as np
+
+from op_test import OpTest
+from test_softmax_op import stable_softmax
+
+
+class TestSoftmaxWithCrossEntropyOp(OpTest):
+    """
+    Test softmax with cross entropy operator with discrete one-hot labels.
+    """
+
+    def setUp(self):
+        self.op_type = "softmax_with_cross_entropy"
+        batch_size = 3
+        class_num = 37
+
+        logits = np.random.uniform(0.1, 1.0,
+                                   [batch_size, class_num]).astype("float32")
+        softmax = np.apply_along_axis(stable_softmax, 1, logits)
+        labels = np.random.randint(0, class_num, [batch_size, 1], dtype="int32")
+
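+        # Reference loss: the negative log-likelihood of each sample's
+        # ground-truth class, matching the hard-label branch of the operator.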
+        cross_entropy = np.asmatrix(
+            [[-np.log(softmax[i][labels[i][0]])]
+             for i in range(softmax.shape[0])],
+            dtype="float32")
+
+        self.inputs = {"Logits": logits, "Label": labels}
+        self.outputs = {"Softmax": softmax, "Loss": cross_entropy}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["Logits"], "Loss", max_relative_error=0.05)
+
+
+class TestSoftmaxWithCrossEntropyOp2(OpTest):
+    """
+    Test softmax with cross entropy operator with soft labels.
+    """
+
+    def setUp(self):
+        self.op_type = "softmax_with_cross_entropy"
+        batch_size = 2
+        class_num = 17
+
+        logits = np.random.uniform(0.1, 1.0,
+                                   [batch_size, class_num]).astype("float32")
+        softmax = np.apply_along_axis(stable_softmax, 1, logits)
+        labels = np.random.uniform(0.1, 1.0,
+                                   [batch_size, class_num]).astype("float32")
+        labels /= np.sum(labels, axis=1, keepdims=True)
+
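+        # Reference loss for soft labels: per-sample cross entropy,
+        # -sum_k label_k * log(softmax_k).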
+        cross_entropy = (-labels * np.log(softmax)).sum(
+            axis=1, keepdims=True).astype("float32")
+
+        self.inputs = {"Logits": logits, "Label": labels}
+        self.outputs = {"Softmax": softmax, "Loss": cross_entropy}
+        self.attrs = {"softLabel": True}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["Logits"], "Loss", max_relative_error=0.05)
+
+
+if __name__ == "__main__":
+    unittest.main()