diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu
deleted file mode 100644
index a3d2e3558c52b5edd2b9c3924c5189930f2acbcf..0000000000000000000000000000000000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu
+++ /dev/null
@@ -1,47 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <stdint.h>
-#include "cross_entropy_cuda_impl.cuh"
-#include "include/cuda_runtime.h"
-
-__global__ void CalCrossEntropyWithGradKernel(const float *softmax_logits, const float *log_softmax_logits,
-                                              const float *labels, const int batch_size, const int num_classes,
-                                              float *loss, float *dx) {
-  extern __shared__ float loss_shared[];
-  const float mean_scale = 1.0f / static_cast<float>(batch_size);
-
-  loss_shared[threadIdx.x] = 0;
-  for (int i = threadIdx.x * num_classes; i < (threadIdx.x + 1) * num_classes; ++i) {
-    loss_shared[threadIdx.x] -= log_softmax_logits[i] * labels[i];
-    dx[i] = (softmax_logits[i] - labels[i]) * mean_scale;
-  }
-  __syncthreads();
-  if (threadIdx.x == 0) {
-    *loss = 0;
-    for (int i = 0; i < batch_size; i++) {
-      *loss += loss_shared[i];
-    }
-    *loss *= mean_scale;
-  }
-}
-
-void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels,
-                             const int batch_size, const int num_classes, float *loss, float *dx,
-                             cudaStream_t cuda_stream) {
-  CalCrossEntropyWithGradKernel<<<1, batch_size, batch_size * sizeof(float), cuda_stream>>>(
-    softmax_logits, log_softmax_logits, labels, batch_size, num_classes, loss, dx);
-}
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh
deleted file mode 100644
index 25b1624a46d326965b756ba1cd5c717816840196..0000000000000000000000000000000000000000
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh
+++ /dev/null
@@ -1,26 +0,0 @@
-/**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_
-#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_
-
-#include "device/gpu/cuda_common.h"
-
-void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels,
-                             const int batch_size, const int num_classes, float *loss, float *dx,
-                             cudaStream_t cuda_stream);
-
-#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu
index 4d0503ba97d41a836cbea9701fd79c9836af9c0b..11c16581d60c41148c870f99ad544802e2ef64d2 100644
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu
@@ -52,38 +52,12 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label
 }
 
 template <typename T, typename S>
-__global__ void CrossEntropyWithoutSparseKernel(const T *logits, const S *labels, const size_t batch_size,
-                                                const size_t class_num, T *losses) {
-  T epsilon = 1e-6;
-  for (size_t i = 0; i < batch_size; ++i) {
-    T logit = 0.0;
-    for (size_t j = 0; j < class_num; j++) {
-      if (fabs(labels[i * class_num + j] - 1.0) <= 1e-8) {
-        logit = logits[i * class_num + j];
-        break;
-      }
-    }
-    if (logit <= 0) {
-      logit += epsilon;
-    }
-    losses[i] = -logf(logit);
+__global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) {
+  losses[threadIdx.x] = 0;
+  for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) {
+    losses[threadIdx.x] -= logf(logits[i]) * labels[i];
+    dlogits[i] = logits[i] - labels[i];
   }
-  return;
-}
-
-template <typename T, typename S>
-__global__ void CrossEntropyGradWithoutSparseKernel(const T *logits, const S *labels, const size_t batch_size,
-                                                    const size_t class_num, T *grad) {
-  for (size_t i = 0; i < batch_size; i++) {
-    for (size_t j = blockIdx.x * blockDim.x + threadIdx.x; j < class_num; j += blockDim.x * gridDim.x) {
-      if (fabs(labels[i * class_num + j] - 1.0) <= 1e-8) {
-        grad[i * class_num + j] = (logits[i * class_num + j] - 1) / batch_size;
-      } else {
-        grad[i * class_num + j] = logits[i * class_num + j] / batch_size;
-      }
-    }
-  }
-  return;
 }
 
 template <typename T, typename S>
@@ -102,18 +76,9 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b
 }
 
 template <typename T, typename S>
-void CrossEntropyWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
-                               T *losses, cudaStream_t cuda_stream) {
-  CrossEntropyWithoutSparseKernel<<<1, 1, 0, cuda_stream>>>(logits, labels, batch_size, class_num, losses);
-  return;
-}
-
-template <typename T, typename S>
-void CrossEntropyGradWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
-                                   T *grad, cudaStream_t cuda_stream) {
-  CrossEntropyGradWithoutSparseKernel<<<GET_BLOCKS(class_num), GET_THREADS, 0, cuda_stream>>>(
-    logits, labels, batch_size, class_num, grad);
-  return;
+void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses,
+                  T *dlogits, cudaStream_t cuda_stream) {
+  CrossEntropyKernel<<<1, batch_size, 0, cuda_stream>>>(logits, labels, class_num, losses, dlogits);
 }
 
 template void CrossEntropyWithSparse<float, int>(const float *logits, const int *labels, const size_t batch_size,
@@ -126,8 +91,6 @@ template void CrossEntropyGradWithSparse<float, int>(const float *logits, const
 template void CrossEntropyGradWithSparse<float, int64_t>(const float *logits, const int64_t *labels,
                                                          const size_t batch_size, const size_t class_num, float *grad,
                                                          cudaStream_t cuda_stream);
-template void CrossEntropyWithoutSparse<float, float>(const float *logits, const float *labels, const size_t batch_size,
-                                                      const size_t class_num, float *losses, cudaStream_t cuda_stream);
-template void CrossEntropyGradWithoutSparse<float, float>(const float *logits, const float *labels,
-                                                          const size_t batch_size, const size_t class_num, float *grad,
-                                                          cudaStream_t cuda_stream);
+template void CrossEntropy<float, float>(const float *logits, const float *labels, const size_t batch_size,
+                                         const size_t class_num, float *losses, float *dlogits,
+                                         cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh
index 00ec13553d39087a96c7b947e02c5c91e32f625d..54ae0728929cfb44d4ecc3d808623431f253f007 100644
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh
@@ -28,11 +28,6 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b
                                 T *grad, cudaStream_t cuda_stream);
 
 template <typename T, typename S>
-void CrossEntropyWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
-                               T *losses, cudaStream_t cuda_stream);
-
-template <typename T, typename S>
-void CrossEntropyGradWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num,
-                                   T *grad, cudaStream_t cuda_stream);
-
+void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses,
+                  T *dlogits, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_
diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h
index 3822a326fbd4f0cbaf1c6c76221f29ec493d0c7a..4d50d4753d4f2fe9ffd0eff0d1062beb7f1336c0 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h
@@ -58,8 +58,8 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel {
     }
     T *logits_addr = GetDeviceAddress<T>(inputs, 0);
     S *labels_addr = GetDeviceAddress<S>(inputs, 1);
-    T *output1_addr = GetDeviceAddress<T>(outputs, 0);
-    T *output2_addr = GetDeviceAddress<T>(outputs, 1);
+    T *loss_addr = GetDeviceAddress<T>(outputs, 0);
+    T *dlogits_addr = GetDeviceAddress<T>(outputs, 1);
     T *softmax_output_logits = GetDeviceAddress<T>(workspace, 0);
 
     const float alpha = 1;
@@ -69,10 +69,8 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel {
                           softmax_output_descriptor_, softmax_output_logits),
       "cudnnSoftmaxForward failed.");
 
-    CrossEntropyWithoutSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output1_addr,
-                              reinterpret_cast<cudaStream_t>(stream_ptr));
-    CrossEntropyGradWithoutSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output2_addr,
-                                  reinterpret_cast<cudaStream_t>(stream_ptr));
+    CrossEntropy(softmax_output_logits, labels_addr, batch_size_, channel_size_, loss_addr, dlogits_addr,
+                 reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {