未验证 提交 d8fc7211 编写于 作者: Z Zhang Zheng 提交者: GitHub

Fix conflict caused by wrong namespace (#39930)

上级 87b903a3
......@@ -32,6 +32,7 @@ namespace operators {
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
using Tensor = framework::Tensor;
namespace kps = phi::kps;
// Wrapper of log function. Use log(float32) for float16
template <typename T>
......@@ -500,7 +501,7 @@ template <typename T, typename AccT, typename LabelT, int VecSize,
bool IgnoreIndex>
__device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
T* loss, T* softmax, const T* logits, const LabelT* label, int size,
const int offset, const LogSoftmaxForwardFunctor<AccT>& func,
const int offset, const phi::LogSoftmaxForwardFunctor<AccT>& func,
const int ignore_index) {
using VecT = kps::details::VectorType<T, VecSize>;
int tid = threadIdx.x;
......@@ -583,7 +584,7 @@ template <typename T, typename AccT, typename LabelT, int VecSize,
bool IgnoreIndex>
__device__ __forceinline__ void ScalarSoftmaxForwardImpl(
T* loss, T* softmax, const T* logits, const LabelT* label, const int size,
const LogSoftmaxForwardFunctor<AccT>& func, const int ignore_index) {
const phi::LogSoftmaxForwardFunctor<AccT>& func, const int ignore_index) {
int tid = threadIdx.x;
int remain = size % (VecSize * blockDim.x);
int label_id = blockIdx.x;
......@@ -658,7 +659,7 @@ __global__ void VectorizedSoftmaxForward(T* loss, T* softmax, const T* logits,
sum, kps::AddFunctor<AccT>());
// 3. softmax
LogSoftmaxForwardFunctor<AccT> func(max, sum);
phi::LogSoftmaxForwardFunctor<AccT> func(max, sum);
if (input_offset == output_offset) {
VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
loss, softmax, logits, label, mid_dim, input_offset, func,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册