From d8fc7211f98c7192f5079b52f87226260c2ea03d Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Fri, 25 Feb 2022 16:13:47 +0800 Subject: [PATCH] Fix conflict caused by wrong namespace (#39930) --- paddle/fluid/operators/softmax_with_cross_entropy_op.cu | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 92e2adb3ee..19a395e723 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -32,6 +32,7 @@ namespace operators { using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; using DataLayout = platform::DataLayout; using Tensor = framework::Tensor; +namespace kps = phi::kps; // Wrapper of log function. Use log(float32) for float16 template @@ -500,7 +501,7 @@ template __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( T* loss, T* softmax, const T* logits, const LabelT* label, int size, - const int offset, const LogSoftmaxForwardFunctor& func, + const int offset, const phi::LogSoftmaxForwardFunctor& func, const int ignore_index) { using VecT = kps::details::VectorType; int tid = threadIdx.x; @@ -583,7 +584,7 @@ template __device__ __forceinline__ void ScalarSoftmaxForwardImpl( T* loss, T* softmax, const T* logits, const LabelT* label, const int size, - const LogSoftmaxForwardFunctor& func, const int ignore_index) { + const phi::LogSoftmaxForwardFunctor& func, const int ignore_index) { int tid = threadIdx.x; int remain = size % (VecSize * blockDim.x); int label_id = blockIdx.x; @@ -658,7 +659,7 @@ __global__ void VectorizedSoftmaxForward(T* loss, T* softmax, const T* logits, sum, kps::AddFunctor()); // 3. softmax - LogSoftmaxForwardFunctor func(max, sum); + phi::LogSoftmaxForwardFunctor func(max, sum); if (input_offset == output_offset) { VectorizedSoftmaxForwardImpl( loss, softmax, logits, label, mid_dim, input_offset, func, -- GitLab