diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc
index fe0c87bc570825014222807cb90a3bb341b44e8e..2b3e7fba411e9bfe66d4070726b4e5e37af39115 100644
--- a/paddle/framework/pybind.cc
+++ b/paddle/framework/pybind.cc
@@ -31,7 +31,7 @@ limitations under the License. */
 namespace py = pybind11;
 
 USE_OP(add_two);
-USE_CPU_ONLY_OP(onehot_cross_entropy);
+USE_OP(onehot_cross_entropy);
 USE_OP(sgd);
 USE_OP(mul);
 USE_OP(mean);
diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc
index a623c551e1088365ade6f73bc6149977b6ef017e..ab1e1c101a10e09a81f7785d2f1514822e3bdf15 100644
--- a/paddle/operators/cross_entropy_op.cc
+++ b/paddle/operators/cross_entropy_op.cc
@@ -39,11 +39,10 @@ class OnehotCrossEntropyGradientOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto X_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto X = ctx.Input<Tensor>("X");
 
-    // TODO(superjom) add enforce here after helper functions ready
-    X_grad->Resize(X->dims());
+    dX->Resize(X->dims());
   }
 };
 
@@ -70,9 +69,7 @@ namespace ops = paddle::operators;
 REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp,
             ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad,
             ops::OnehotCrossEntropyGradientOp);
-REGISTER_OP_CPU_KERNEL(
-    onehot_cross_entropy,
-    ops::OnehotCrossEntropyOpKernel<paddle::platform::CPUPlace, float>);
-REGISTER_OP_CPU_KERNEL(
-    onehot_cross_entropy_grad,
-    ops::OnehotCrossEntropyGradientOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(onehot_cross_entropy,
+                       ops::OnehotCrossEntropyOpKernel<float>);
+REGISTER_OP_CPU_KERNEL(onehot_cross_entropy_grad,
+                       ops::OnehotCrossEntropyGradientOpKernel<float>);
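For reference, the forward pass these kernels compute is Y[i] = -log(X[i * D + label[i]]) over an N x D matrix of per-class probabilities. A minimal standalone C++ sketch of that math, using illustrative names that are not part of the operator:

#include <cmath>
#include <cstdio>

// Reference forward pass: Y[i] = -log(X[i * D + label[i]]).
void OnehotCrossEntropyRef(float* Y, const float* X, const int* label,
                           const int N, const int D) {
  for (int i = 0; i < N; ++i) {
    Y[i] = -std::log(X[i * D + label[i]]);
  }
}

int main() {
  const int N = 2, D = 3;
  const float X[N * D] = {0.2f, 0.5f, 0.3f,   // row 0, label 1 -> -log(0.5)
                          0.1f, 0.1f, 0.8f};  // row 1, label 2 -> -log(0.8)
  const int label[N] = {1, 2};
  float Y[N];
  OnehotCrossEntropyRef(Y, X, label, N, D);
  std::printf("%f %f\n", Y[0], Y[1]);
  return 0;
}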
diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu
index 4bbc8f093a794d46737a16488684a6a0cc25e285..2392c3d5ed98a0c406b9d9d6807a1578828f9e0f 100644
--- a/paddle/operators/cross_entropy_op.cu
+++ b/paddle/operators/cross_entropy_op.cu
@@ -12,10 +12,108 @@
    See the License for the specific language governing permissions and
    limitations under the License. */
 
-#define EIGEN_USE_GPU
-#include "paddle/operators/cross_entropy_op.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/platform/assert.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+__global__ void CrossEntropyKernel(T* Y, const T* X, const int* label,
+                                   const int N, const int D) {
+  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
+  // CUDA_1D_KERNEL_LOOP(i, N) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
+    Y[i] = -log(X[i * D + label[i]]);
+  }
+}
+
+template <typename T>
+__global__ void zero(T* X, const int N) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    X[i] = 0.0;
+  }
+}
+
+template <typename T>
+__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
+                                           const int* label, const int N,
+                                           const int D) {
+  // TODO(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
+  // CUDA_1D_KERNEL_LOOP(i, N) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
+       i += blockDim.x * gridDim.x) {
+    int idx = i * D + label[i];
+    dX[idx] = -dY[i] / X[idx];
+  }
+}
+
+template <typename T>
+class OnehotCrossEntropyOpCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "It must use GPUPlace.");
+
+    auto X = ctx.Input<Tensor>("X");
+    const T* Xdata = X->data<T>();
+    const int* label_data = ctx.Input<Tensor>("label")->data<int>();
+    auto Y = ctx.Output<Tensor>("Y");
+    Y->mutable_data<T>(ctx.GetPlace());
+    T* Ydata = Y->data<T>();
+
+    int N = X->dims()[0];
+    int D = X->dims()[1];
+    int block = 512;
+    int grid = (N + block - 1) / block;
+    // TODO(qingqing): launch the kernel on the stream specified by the
+    // ExecutionContext.
+    CrossEntropyKernel<T><<<grid, block>>>(Ydata, Xdata, label_data, N, D);
+  }
+};
+
+template <typename T>
+class OnehotCrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+                   "It must use GPUPlace.");
+
+    auto X = ctx.Input<Tensor>("X");
+    auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    auto label = ctx.Input<Tensor>("label");
+
+    auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
+    auto* dYdata = dY->template data<T>();
+    auto* Xdata = X->template data<T>();
+    auto* label_data = label->data<int>();
+
+    int N = X->dims()[0];
+    int D = X->dims()[1];
+    int block = 512;
+    int grid = (N * D + block - 1) / block;
+    // TODO(qingqing): make zero a common helper function.
+    zero<T><<<grid, block>>>(dXdata, N * D);
+
+    grid = (N + block - 1) / block;
+    // TODO(qingqing): launch the kernel on the stream specified by the
+    // ExecutionContext.
+    CrossEntropyGradientKernel<T><<<grid, block>>>(dXdata, dYdata, Xdata,
+                                                   label_data, N, D);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(
-    onehot_cross_entropy,
-    ops::OnehotCrossEntropyOpKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
+                       ops::OnehotCrossEntropyOpCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(onehot_cross_entropy_grad,
+                       ops::OnehotCrossEntropyGradientOpCUDAKernel<float>);
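Both TODO comments above point at a CUDA_1D_KERNEL_LOOP macro that does not exist in the tree yet. A sketch of what such a grid-stride-loop macro typically looks like (the name comes from the TODOs; the definition below is an assumption, not code from this repository):

// Hypothetical grid-stride loop macro, per the TODO comments; not yet
// defined anywhere in the codebase.
#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

// With it, the forward kernel body would collapse to:
//   CUDA_1D_KERNEL_LOOP(i, N) {
//     PADDLE_ASSERT(label[i] >= 0 && label[i] < D);
//     Y[i] = -log(X[i * D + label[i]]);
//   }

The launch configuration grid = (N + block - 1) / block is the usual ceiling division, so all N elements are covered even when N is not a multiple of the 512-thread block, and the grid-stride loop keeps the kernel correct for any grid size.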
diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h
index b7df92c9a98ebf12b72a8d3d8e8e4e1a950f06c9..261cbe2d423ab5fcb7bbd75ceb1c397035a1b2bc 100644
--- a/paddle/operators/cross_entropy_op.h
+++ b/paddle/operators/cross_entropy_op.h
@@ -39,10 +39,13 @@ T tolerable_value(T x) {
   return x;
 }
 
-template <typename Place, typename T>
+template <typename T>
 class OnehotCrossEntropyOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
     auto X = ctx.Input<Tensor>("X");
     const T* Xdata = X->data<T>();
     const int* label_data = ctx.Input<Tensor>("label")->data<int>();
@@ -62,10 +65,13 @@ class OnehotCrossEntropyOpKernel : public framework::OpKernel {
   }
 };
 
-template <typename Place, typename T>
+template <typename T>
 class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
     auto X = ctx.Input<Tensor>("X");
     auto dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto dY = ctx.Input<Tensor>(framework::GradVarName("Y"));
@@ -79,6 +85,7 @@ class OnehotCrossEntropyGradientOpKernel : public framework::OpKernel {
     const int batch_size = X->dims()[0];
     const int class_num = X->dims()[1];
 
+    memset(dXdata, 0, sizeof(T) * batch_size * class_num);
     for (int i = 0; i < batch_size; ++i) {
       int index = i * class_num + label_data[i];
       dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]);
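The added memset is a correctness fix: the backward pass writes only one element per row (index i * class_num + label[i]), and mutable_data returns uninitialized memory, so every other entry of dX would otherwise contain garbage. The CUDA path handles the same issue with the zero kernel. A tiny self-contained illustration of the sparse gradient pattern (the numbers are made up):

#include <cstdio>
#include <cstring>

int main() {
  // One row with class_num = 3 and label = 1: only the label column
  // receives a gradient, so dX must start out all-zero.
  const int class_num = 3;
  const float X[class_num] = {0.2f, 0.5f, 0.3f};
  const float dY = 1.0f;
  const int label = 1;
  float dX[class_num];
  std::memset(dX, 0, sizeof(dX));
  dX[label] = -dY / X[label];
  std::printf("%f %f %f\n", dX[0], dX[1], dX[2]);  // 0, -2, 0
  return 0;
}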
diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
index 4815192e255c6e0429db3f50918a76a773b30131..5557e0d35820dffec9597596cfbff6ed53b7d550 100644
--- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py
+++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py
@@ -22,7 +22,7 @@ class TestCrossEntropy(unittest.TestCase):
 
 
 class CrossEntropyGradOpTest(GradientChecker):
-    def test_softmax_grad(self):
+    def test_check_grad(self):
         op = create_op("onehot_cross_entropy")
         batch_size = 100
         class_num = 10