未验证 提交 3a53b77e 编写于 作者: Y Yiqun Liu 提交者: GitHub

[cherry-pick] Fix the index calculation in cross_entroy_kernel. (#53659) (#53765)

上级 3641c5ec
......@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/phi/kernels/cross_entropy_kernel.h"
#include "glog/logging.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
......@@ -468,8 +470,8 @@ __global__ void VectorizedSoftmaxForward(T* loss,
using VecT = kps::details::VectorType<T, VecSize>;
// each block deal with one batch
logits += blockIdx.x * mid_dim;
softmax += blockIdx.x * mid_dim;
logits += static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(mid_dim);
softmax += static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(mid_dim);
const int input_offset = ((uint64_t)logits) % ALIGN_BYTES / sizeof(T);
const int output_offset = ((uint64_t)softmax) % ALIGN_BYTES / sizeof(T);
......@@ -1165,6 +1167,8 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
int dim,
int D,
const int ignore_index) {
VLOG(7) << "rank=" << rank << ", axis = " << axis << ", N = " << N
<< ", dim = " << dim << ", D = " << D;
auto stream = dev_ctx.stream();
constexpr int max_dim = 320;
if (D == 1) {
......@@ -1247,11 +1251,11 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
int axis,
DenseTensor* softmax,
DenseTensor* loss) {
PADDLE_ENFORCE_EQ(
dev_ctx.GetPlace().GetType(),
AllocationType::GPU,
phi::errors::Unavailable("softmax_with_cross_entropy operator's "
"CUDA kernel only runs on GPU device."));
VLOG(7) << "logits.shape={" << logits.dims() << "}, label.shape={"
<< label.dims() << "}, soft_label=" << soft_label
<< ", use_softmax=" << use_softmax
<< ", numeric_stable_mode=" << numeric_stable_mode
<< ", ignore_index=" << ignore_index << ", axis=" << axis;
// do not with softmax op, and input is softmax
if (!use_softmax) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册