未验证 提交 d8fc7211 编写于 作者: Z Zhang Zheng 提交者: GitHub

Fix conflict caused by wrong namespace (#39930)

上级 87b903a3
...@@ -32,6 +32,7 @@ namespace operators { ...@@ -32,6 +32,7 @@ namespace operators {
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout; using DataLayout = platform::DataLayout;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
namespace kps = phi::kps;
// Wrapper of log function. Use log(float32) for float16 // Wrapper of log function. Use log(float32) for float16
template <typename T> template <typename T>
...@@ -500,7 +501,7 @@ template <typename T, typename AccT, typename LabelT, int VecSize, ...@@ -500,7 +501,7 @@ template <typename T, typename AccT, typename LabelT, int VecSize,
bool IgnoreIndex> bool IgnoreIndex>
__device__ __forceinline__ void VectorizedSoftmaxForwardImpl( __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
T* loss, T* softmax, const T* logits, const LabelT* label, int size, T* loss, T* softmax, const T* logits, const LabelT* label, int size,
const int offset, const LogSoftmaxForwardFunctor<AccT>& func, const int offset, const phi::LogSoftmaxForwardFunctor<AccT>& func,
const int ignore_index) { const int ignore_index) {
using VecT = kps::details::VectorType<T, VecSize>; using VecT = kps::details::VectorType<T, VecSize>;
int tid = threadIdx.x; int tid = threadIdx.x;
...@@ -583,7 +584,7 @@ template <typename T, typename AccT, typename LabelT, int VecSize, ...@@ -583,7 +584,7 @@ template <typename T, typename AccT, typename LabelT, int VecSize,
bool IgnoreIndex> bool IgnoreIndex>
__device__ __forceinline__ void ScalarSoftmaxForwardImpl( __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
T* loss, T* softmax, const T* logits, const LabelT* label, const int size, T* loss, T* softmax, const T* logits, const LabelT* label, const int size,
const LogSoftmaxForwardFunctor<AccT>& func, const int ignore_index) { const phi::LogSoftmaxForwardFunctor<AccT>& func, const int ignore_index) {
int tid = threadIdx.x; int tid = threadIdx.x;
int remain = size % (VecSize * blockDim.x); int remain = size % (VecSize * blockDim.x);
int label_id = blockIdx.x; int label_id = blockIdx.x;
...@@ -658,7 +659,7 @@ __global__ void VectorizedSoftmaxForward(T* loss, T* softmax, const T* logits, ...@@ -658,7 +659,7 @@ __global__ void VectorizedSoftmaxForward(T* loss, T* softmax, const T* logits,
sum, kps::AddFunctor<AccT>()); sum, kps::AddFunctor<AccT>());
// 3. softmax // 3. softmax
LogSoftmaxForwardFunctor<AccT> func(max, sum); phi::LogSoftmaxForwardFunctor<AccT> func(max, sum);
if (input_offset == output_offset) { if (input_offset == output_offset) {
VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>( VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
loss, softmax, logits, label, mid_dim, input_offset, func, loss, softmax, logits, label, mid_dim, input_offset, func,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册