diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu deleted file mode 100644 index a3d2e3558c52b5edd2b9c3924c5189930f2acbcf..0000000000000000000000000000000000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include <stdint.h> -#include "cross_entropy_cuda_impl.cuh" -#include "include/cuda_runtime.h" - -__global__ void CalCrossEntropyWithGradKernel(const float *softmax_logits, const float *log_softmax_logits, - const float *labels, const int batch_size, const int num_classes, - float *loss, float *dx) { - extern __shared__ float loss_shared[]; - const float mean_scale = 1.0f / static_cast<float>(batch_size); - - loss_shared[threadIdx.x] = 0; - for (int i = threadIdx.x * num_classes; i < (threadIdx.x + 1) * num_classes; ++i) { - loss_shared[threadIdx.x] -= log_softmax_logits[i] * labels[i]; - dx[i] = (softmax_logits[i] - labels[i]) * mean_scale; - } - __syncthreads(); - if (threadIdx.x == 0) { - *loss = 0; - for (int i = 0; i < batch_size; i++) { - *loss += loss_shared[i]; - } - *loss *= mean_scale; - } -} - -void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels, - const int batch_size, const int num_classes, float *loss, float *dx, - cudaStream_t cuda_stream) { - CalCrossEntropyWithGradKernel<<<1, batch_size, batch_size * sizeof(float), cuda_stream>>>( - softmax_logits, log_softmax_logits, labels, batch_size, num_classes, loss, dx); -} diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh deleted file mode 100644 index 25b1624a46d326965b756ba1cd5c717816840196..0000000000000000000000000000000000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ - -#include "device/gpu/cuda_common.h" - -void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels, - const int batch_size, const int num_classes, float *loss, float *dx, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu index 4d0503ba97d41a836cbea9701fd79c9836af9c0b..11c16581d60c41148c870f99ad544802e2ef64d2 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu @@ -52,38 +52,12 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label } template <typename T, typename S> -__global__ void CrossEntropyWithoutSparseKernel(const T *logits, const S *labels, const size_t batch_size, - const size_t class_num, T *losses) { - T epsilon = 1e-6; - for (size_t i = 0; i < batch_size; ++i) { - T logit = 0.0; - for (size_t j = 0; j < class_num; j++) { - if (fabs(labels[i * class_num + j] - 1.0) <= 1e-8) { - logit = logits[i * class_num + j]; - break; - } - } - if (logit <= 0) { - logit += epsilon; - } - losses[i] = -logf(logit); +__global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) { + losses[threadIdx.x] = 0; + for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) { + losses[threadIdx.x] -= logf(logits[i]) * labels[i]; + dlogits[i] = logits[i] - labels[i]; } - return; -} - -template <typename T, typename S> -__global__ void CrossEntropyGradWithoutSparseKernel(const T *logits, const S *labels, const size_t batch_size, - const size_t class_num, T *grad) { - for (size_t i = 0; i < batch_size; i++) { - for (size_t j = blockIdx.x * blockDim.x + threadIdx.x; j < class_num; j += blockDim.x * gridDim.x) { - if (fabs(labels[i * class_num + j] - 1.0) <= 1e-8) { - grad[i * class_num + j] = (logits[i * class_num + j] - 1) / batch_size; - } else { - grad[i * class_num + j] = logits[i * class_num + j] / batch_size; - } - } - } - return; } template <typename T, typename S> @@ -102,18 +76,9 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b } template <typename T, typename S> -void CrossEntropyWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, - T *losses, cudaStream_t cuda_stream) { - CrossEntropyWithoutSparseKernel<<<1, 1, 0, cuda_stream>>>(logits, labels, batch_size, class_num, losses); - return; -} - -template <typename T, typename S> -void CrossEntropyGradWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, - T *grad, cudaStream_t cuda_stream) { - CrossEntropyGradWithoutSparseKernel<<<GET_BLOCKS(class_num), GET_THREADS, 0, cuda_stream>>>( - logits, labels, batch_size, class_num, grad); - return; +void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses, + T *dlogits, cudaStream_t cuda_stream) { + CrossEntropyKernel<<<1, batch_size, 0, cuda_stream>>>(logits, labels, class_num, losses, dlogits); } template void CrossEntropyWithSparse<float, int>(const float *logits, const int *labels, const size_t batch_size, @@ -126,8 +91,6 @@ template void CrossEntropyGradWithSparse<float, int>(const float *logits, const template void CrossEntropyGradWithSparse<float, int64_t>(const float *logits, const int64_t *labels, const size_t batch_size, const size_t class_num, float *grad, cudaStream_t cuda_stream); -template void CrossEntropyWithoutSparse<float, float>(const float *logits, const float *labels, const size_t batch_size, - const size_t class_num, float *losses, cudaStream_t cuda_stream); -template void CrossEntropyGradWithoutSparse<float, float>(const float *logits, const float *labels, - const size_t batch_size, const size_t class_num, float *grad, - cudaStream_t cuda_stream); +template void CrossEntropy<float, float>(const float *logits, const float *labels, const size_t batch_size, + const size_t class_num, float *losses, float *dlogits, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh index 00ec13553d39087a96c7b947e02c5c91e32f625d..54ae0728929cfb44d4ecc3d808623431f253f007 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh @@ -28,11 +28,6 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b T *grad, cudaStream_t cuda_stream); template <typename T, typename S> -void CrossEntropyWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, - T *losses, cudaStream_t cuda_stream); - -template <typename T, typename S> -void CrossEntropyGradWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, - T *grad, cudaStream_t cuda_stream); - +void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses, + T *dlogits, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h index 3822a326fbd4f0cbaf1c6c76221f29ec493d0c7a..4d50d4753d4f2fe9ffd0eff0d1062beb7f1336c0 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h @@ -58,8 +58,8 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { } T *logits_addr = GetDeviceAddress<T>(inputs, 0); S *labels_addr = GetDeviceAddress<S>(inputs, 1); - T *output1_addr = GetDeviceAddress<T>(outputs, 0); - T *output2_addr = GetDeviceAddress<T>(outputs, 1); + T *loss_addr = GetDeviceAddress<T>(outputs, 0); + T *dlogits_addr = GetDeviceAddress<T>(outputs, 1); T *softmax_output_logits = GetDeviceAddress<T>(workspace, 0); const float alpha = 1; @@ -69,10 +69,8 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { softmax_output_descriptor_, softmax_output_logits), "cudnnSoftmaxForward failed."); - CrossEntropyWithoutSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output1_addr, - reinterpret_cast<cudaStream_t>(stream_ptr)); - CrossEntropyGradWithoutSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output2_addr, - reinterpret_cast<cudaStream_t>(stream_ptr)); + CrossEntropy(softmax_output_logits, labels_addr, batch_size_, channel_size_, loss_addr, dlogits_addr, + reinterpret_cast<cudaStream_t>(stream_ptr)); return true; } bool Init(const CNodePtr &kernel_node) override {