Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
34122e3e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
34122e3e
编写于
5月 15, 2023
作者:
Z
zhangyuqin1998
提交者:
GitHub
5月 15, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move OneHotRawKernel to legacy (#53200)
* move OneHotRawKernel to legacy * fix
上级
3e90a461
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
312 addition
and
142 deletion
+312
-142
paddle/phi/kernels/cpu/one_hot_kernel.cc
paddle/phi/kernels/cpu/one_hot_kernel.cc
+30
-11
paddle/phi/kernels/gpu/one_hot_kernel.cu
paddle/phi/kernels/gpu/one_hot_kernel.cu
+18
-41
paddle/phi/kernels/legacy/cpu/one_hot_kernel.cc
paddle/phi/kernels/legacy/cpu/one_hot_kernel.cc
+88
-0
paddle/phi/kernels/legacy/gpu/one_hot_kernel.cu
paddle/phi/kernels/legacy/gpu/one_hot_kernel.cu
+96
-0
paddle/phi/kernels/legacy/xpu/one_hot_kernel.cc
paddle/phi/kernels/legacy/xpu/one_hot_kernel.cc
+68
-0
paddle/phi/kernels/one_hot_kernel.cc
paddle/phi/kernels/one_hot_kernel.cc
+0
-47
paddle/phi/kernels/one_hot_kernel.h
paddle/phi/kernels/one_hot_kernel.h
+0
-8
paddle/phi/kernels/xpu/one_hot_kernel.cc
paddle/phi/kernels/xpu/one_hot_kernel.cc
+12
-35
未找到文件。
paddle/phi/kernels/cpu/one_hot_kernel.cc
浏览文件 @
34122e3e
...
...
@@ -63,12 +63,10 @@ struct OneHotV2OpFunctor {
};
template
<
typename
T
,
typename
Context
>
void
OneHotRawKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
Scalar
&
depth
,
DataType
dtype
,
bool
allow_out_of_range
,
DenseTensor
*
out
)
{
void
OneHotKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
Scalar
&
depth
,
DenseTensor
*
out
)
{
auto
depth_v
=
depth
.
to
<
int
>
();
auto
out_dims
=
out
->
dims
();
if
(
out_dims
[
out_dims
.
size
()
-
1
]
==
-
1
)
{
...
...
@@ -76,13 +74,34 @@ void OneHotRawKernel(const Context& dev_ctx,
out
->
Resize
(
out_dims
);
}
phi
::
VisitDataType
(
dtype
,
OneHotV2OpFunctor
<
Context
,
T
>
(
&
x
,
out
,
depth_v
,
dev_ctx
));
auto
*
p_in_data
=
x
.
data
<
T
>
();
auto
numel
=
x
.
numel
();
auto
*
p_out_data
=
dev_ctx
.
template
Alloc
<
float
>(
out
);
funcs
::
set_constant
(
dev_ctx
,
out
,
0.0
);
for
(
int
i
=
0
;
i
<
numel
;
++
i
)
{
PADDLE_ENFORCE_GE
(
p_in_data
[
i
],
0
,
phi
::
errors
::
InvalidArgument
(
"Illegal index value, Input(input) value should be at least 0, "
"but received input (%d) less than 0"
,
p_in_data
[
i
]));
PADDLE_ENFORCE_LT
(
p_in_data
[
i
],
depth_v
,
phi
::
errors
::
InvalidArgument
(
"Illegal index value, Input(input) value should be less than "
"Input(depth), "
"but received input (%d) not less than depth (%d)"
,
p_in_data
[
i
],
depth_v
));
*
(
p_out_data
+
i
*
depth_v
+
p_in_data
[
i
])
=
1.0
;
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
one_hot_raw
,
CPU
,
ALL_LAYOUT
,
phi
::
OneHotRawKernel
,
int
,
int64_t
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
UNDEFINED
);
PD_REGISTER_KERNEL
(
one_hot
,
CPU
,
ALL_LAYOUT
,
phi
::
OneHotKernel
,
int
,
int64_t
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
}
paddle/phi/kernels/gpu/one_hot_kernel.cu
浏览文件 @
34122e3e
...
...
@@ -40,43 +40,11 @@ __global__ void FillOutputKernel(const InT* p_in_data,
}
}
template
<
typename
DeviceContext
,
typename
InT
>
struct
OneHotV2OpCUDAFunctor
{
const
DenseTensor
*
in_
;
DenseTensor
*
out_
;
const
DeviceContext
&
ctx_
;
int
depth_
;
OneHotV2OpCUDAFunctor
(
const
DenseTensor
*
in
,
DenseTensor
*
out
,
int
depth
,
const
DeviceContext
&
ctx
)
:
in_
(
in
),
out_
(
out
),
depth_
(
depth
),
ctx_
(
ctx
)
{}
template
<
typename
OutT
>
void
apply
()
const
{
auto
*
p_in_data
=
in_
->
data
<
InT
>
();
auto
numel
=
in_
->
numel
();
auto
*
p_out_data
=
ctx_
.
template
Alloc
<
OutT
>(
out_
);
auto
stream
=
ctx_
.
stream
();
funcs
::
set_constant
(
ctx_
,
out_
,
0.0
);
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
ctx_
,
numel
);
FillOutputKernel
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
stream
>>>
(
p_in_data
,
p_out_data
,
numel
,
depth_
);
}
};
template
<
typename
T
,
typename
Context
>
void
OneHotRawKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
Scalar
&
depth
,
DataType
dtype
,
bool
allow_out_of_range
,
DenseTensor
*
out
)
{
void
OneHotKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
Scalar
&
depth
,
DenseTensor
*
out
)
{
auto
depth_v
=
depth
.
to
<
int
>
();
auto
out_dims
=
out
->
dims
();
if
(
out_dims
[
out_dims
.
size
()
-
1
]
==
-
1
)
{
...
...
@@ -84,13 +52,22 @@ void OneHotRawKernel(const Context& dev_ctx,
out
->
Resize
(
out_dims
);
}
phi
::
VisitDataType
(
dtype
,
OneHotV2OpCUDAFunctor
<
Context
,
T
>
(
&
x
,
out
,
depth_v
,
dev_ctx
));
auto
*
p_in_data
=
x
.
data
<
T
>
();
auto
numel
=
x
.
numel
();
auto
*
p_out_data
=
dev_ctx
.
template
Alloc
<
float
>(
out
);
auto
stream
=
dev_ctx
.
stream
();
funcs
::
set_constant
(
dev_ctx
,
out
,
0.0
);
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
numel
);
FillOutputKernel
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
stream
>>>
(
p_in_data
,
p_out_data
,
numel
,
depth_v
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
one_hot_raw
,
GPU
,
ALL_LAYOUT
,
phi
::
OneHotRawKernel
,
int
,
int64_t
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
UNDEFINED
);
PD_REGISTER_KERNEL
(
one_hot
,
GPU
,
ALL_LAYOUT
,
phi
::
OneHotKernel
,
int
,
int64_t
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
}
paddle/phi/kernels/legacy/cpu/one_hot_kernel.cc
0 → 100644
浏览文件 @
34122e3e
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
phi
{
template
<
typename
DeviceContext
,
typename
InT
>
struct
OneHotV2OpFunctor
{
const
DenseTensor
*
in_
;
DenseTensor
*
out_
;
int
depth_
;
const
DeviceContext
&
ctx_
;
OneHotV2OpFunctor
(
const
DenseTensor
*
in
,
DenseTensor
*
out
,
int
depth
,
const
DeviceContext
&
ctx
)
:
in_
(
in
),
out_
(
out
),
depth_
(
depth
),
ctx_
(
ctx
)
{}
template
<
typename
OutT
>
void
apply
()
const
{
auto
*
p_in_data
=
in_
->
data
<
InT
>
();
auto
numel
=
in_
->
numel
();
auto
*
p_out_data
=
ctx_
.
template
Alloc
<
OutT
>(
out_
);
funcs
::
set_constant
(
ctx_
,
out_
,
0.0
);
for
(
int
i
=
0
;
i
<
numel
;
++
i
)
{
PADDLE_ENFORCE_GE
(
p_in_data
[
i
],
0
,
phi
::
errors
::
InvalidArgument
(
"Illegal index value, Input(input) value should be at least 0, "
"but received input (%d) less than 0"
,
p_in_data
[
i
]));
PADDLE_ENFORCE_LT
(
p_in_data
[
i
],
depth_
,
phi
::
errors
::
InvalidArgument
(
"Illegal index value, Input(input) value should be less than "
"Input(depth), "
"but received input (%d) not less than depth (%d)"
,
p_in_data
[
i
],
depth_
));
*
(
p_out_data
+
i
*
depth_
+
p_in_data
[
i
])
=
1.0
;
}
}
};
template
<
typename
T
,
typename
Context
>
void
OneHotRawKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
Scalar
&
depth
,
DataType
dtype
,
bool
allow_out_of_range
,
DenseTensor
*
out
)
{
auto
depth_v
=
depth
.
to
<
int
>
();
auto
out_dims
=
out
->
dims
();
if
(
out_dims
[
out_dims
.
size
()
-
1
]
==
-
1
)
{
out_dims
[
out_dims
.
size
()
-
1
]
=
depth_v
;
out
->
Resize
(
out_dims
);
}
phi
::
VisitDataType
(
dtype
,
OneHotV2OpFunctor
<
Context
,
T
>
(
&
x
,
out
,
depth_v
,
dev_ctx
));
}
}
// namespace phi
PD_REGISTER_KERNEL
(
one_hot_raw
,
CPU
,
ALL_LAYOUT
,
phi
::
OneHotRawKernel
,
int
,
int64_t
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
UNDEFINED
);
}
paddle/phi/kernels/legacy/gpu/one_hot_kernel.cu
0 → 100644
浏览文件 @
34122e3e
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
phi
{
using
phi
::
PADDLE_CUDA_NUM_THREADS
;
template
<
typename
InT
,
typename
OutT
>
__global__
void
FillOutputKernel
(
const
InT
*
p_in_data
,
OutT
*
p_out_data
,
const
int64_t
numel
,
const
int
depth
)
{
CUDA_KERNEL_LOOP_TYPE
(
idx
,
numel
,
int64_t
)
{
PADDLE_ENFORCE
(
p_in_data
[
idx
]
>=
0
&&
p_in_data
[
idx
]
<
depth
,
"Illegal index value, Input(input) value should be "
"greater than or equal to 0, and less than depth [%d], "
"but received [%lld]."
,
depth
,
p_in_data
[
idx
]);
*
(
p_out_data
+
(
idx
*
depth
)
+
p_in_data
[
idx
])
=
1.0
;
}
}
template
<
typename
DeviceContext
,
typename
InT
>
struct
OneHotV2OpCUDAFunctor
{
const
DenseTensor
*
in_
;
DenseTensor
*
out_
;
const
DeviceContext
&
ctx_
;
int
depth_
;
OneHotV2OpCUDAFunctor
(
const
DenseTensor
*
in
,
DenseTensor
*
out
,
int
depth
,
const
DeviceContext
&
ctx
)
:
in_
(
in
),
out_
(
out
),
depth_
(
depth
),
ctx_
(
ctx
)
{}
template
<
typename
OutT
>
void
apply
()
const
{
auto
*
p_in_data
=
in_
->
data
<
InT
>
();
auto
numel
=
in_
->
numel
();
auto
*
p_out_data
=
ctx_
.
template
Alloc
<
OutT
>(
out_
);
auto
stream
=
ctx_
.
stream
();
funcs
::
set_constant
(
ctx_
,
out_
,
0.0
);
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
ctx_
,
numel
);
FillOutputKernel
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
stream
>>>
(
p_in_data
,
p_out_data
,
numel
,
depth_
);
}
};
template
<
typename
T
,
typename
Context
>
void
OneHotRawKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
Scalar
&
depth
,
DataType
dtype
,
bool
allow_out_of_range
,
DenseTensor
*
out
)
{
auto
depth_v
=
depth
.
to
<
int
>
();
auto
out_dims
=
out
->
dims
();
if
(
out_dims
[
out_dims
.
size
()
-
1
]
==
-
1
)
{
out_dims
[
out_dims
.
size
()
-
1
]
=
depth_v
;
out
->
Resize
(
out_dims
);
}
phi
::
VisitDataType
(
dtype
,
OneHotV2OpCUDAFunctor
<
Context
,
T
>
(
&
x
,
out
,
depth_v
,
dev_ctx
));
}
}
// namespace phi
PD_REGISTER_KERNEL
(
one_hot_raw
,
GPU
,
ALL_LAYOUT
,
phi
::
OneHotRawKernel
,
int
,
int64_t
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
UNDEFINED
);
}
paddle/phi/kernels/legacy/xpu/one_hot_kernel.cc
0 → 100644
浏览文件 @
34122e3e
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace
phi
{
template
<
typename
Context
,
typename
InT
>
struct
OneHotV2OpFunctor
{
const
DenseTensor
*
in_
;
DenseTensor
*
out_
;
int
depth_
;
const
Context
&
ctx_
;
OneHotV2OpFunctor
(
const
DenseTensor
*
in
,
DenseTensor
*
out
,
int
depth
,
const
Context
&
ctx
)
:
in_
(
in
),
out_
(
out
),
depth_
(
depth
),
ctx_
(
ctx
)
{}
template
<
typename
OutT
>
void
apply
()
const
{
auto
*
p_in_data
=
in_
->
data
<
InT
>
();
auto
numel
=
in_
->
numel
();
auto
*
p_out_data
=
ctx_
.
template
Alloc
<
float
>(
out_
);
int
r
=
xpu
::
one_hot
<
InT
>
(
ctx_
.
x_context
(),
p_in_data
,
p_out_data
,
numel
,
depth_
,
1.0
,
0.0
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"one_hot"
);
}
};
template
<
typename
T
,
typename
Context
>
void
OneHotRawKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
Scalar
&
depth
,
DataType
dtype
,
bool
allow_out_of_range
,
DenseTensor
*
out
)
{
auto
depth_v
=
depth
.
to
<
int
>
();
auto
out_dims
=
out
->
dims
();
if
(
out_dims
[
out_dims
.
size
()
-
1
]
==
-
1
)
{
out_dims
[
out_dims
.
size
()
-
1
]
=
depth_v
;
out
->
Resize
(
out_dims
);
}
phi
::
VisitDataType
(
dtype
,
OneHotV2OpFunctor
<
Context
,
T
>
(
&
x
,
out
,
depth_v
,
dev_ctx
));
}
}
// namespace phi
PD_REGISTER_KERNEL
(
one_hot_raw
,
XPU
,
ALL_LAYOUT
,
phi
::
OneHotRawKernel
,
int
,
int64_t
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
UNDEFINED
);
}
paddle/phi/kernels/one_hot_kernel.cc
已删除
100644 → 0
浏览文件 @
3e90a461
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/one_hot_kernel.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
OneHotKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
Scalar
&
num_classes_s
,
DenseTensor
*
out
)
{
OneHotRawKernel
<
T
>
(
dev_ctx
,
x
,
num_classes_s
,
phi
::
DataType
::
FLOAT32
,
false
,
out
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
one_hot
,
CPU
,
ALL_LAYOUT
,
phi
::
OneHotKernel
,
int
,
int64_t
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL
(
one_hot
,
GPU
,
ALL_LAYOUT
,
phi
::
OneHotKernel
,
int
,
int64_t
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
}
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL
(
one_hot
,
XPU
,
ALL_LAYOUT
,
phi
::
OneHotKernel
,
int
,
int64_t
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
}
#endif
paddle/phi/kernels/one_hot_kernel.h
浏览文件 @
34122e3e
...
...
@@ -25,12 +25,4 @@ void OneHotKernel(const Context& dev_ctx,
const
Scalar
&
num_classes
,
DenseTensor
*
out
);
template
<
typename
T
,
typename
Context
>
void
OneHotRawKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
Scalar
&
depth
,
DataType
dtype
,
bool
allow_out_of_range
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/kernels/xpu/one_hot_kernel.cc
浏览文件 @
34122e3e
...
...
@@ -19,49 +19,26 @@
#include "paddle/phi/core/utils/data_type.h"
namespace
phi
{
template
<
typename
Context
,
typename
InT
>
struct
OneHotV2OpFunctor
{
const
DenseTensor
*
in_
;
DenseTensor
*
out_
;
int
depth_
;
const
Context
&
ctx_
;
OneHotV2OpFunctor
(
const
DenseTensor
*
in
,
DenseTensor
*
out
,
int
depth
,
const
Context
&
ctx
)
:
in_
(
in
),
out_
(
out
),
depth_
(
depth
),
ctx_
(
ctx
)
{}
template
<
typename
OutT
>
void
apply
()
const
{
auto
*
p_in_data
=
in_
->
data
<
InT
>
();
auto
numel
=
in_
->
numel
();
auto
*
p_out_data
=
ctx_
.
template
Alloc
<
float
>(
out_
);
int
r
=
xpu
::
one_hot
<
InT
>
(
ctx_
.
x_context
(),
p_in_data
,
p_out_data
,
numel
,
depth_
,
1.0
,
0.0
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"one_hot"
);
}
};
template
<
typename
T
,
typename
Context
>
void
OneHotRawKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
Scalar
&
depth
,
DataType
dtype
,
bool
allow_out_of_range
,
DenseTensor
*
out
)
{
void
OneHotKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
Scalar
&
depth
,
DenseTensor
*
out
)
{
auto
depth_v
=
depth
.
to
<
int
>
();
auto
out_dims
=
out
->
dims
();
if
(
out_dims
[
out_dims
.
size
()
-
1
]
==
-
1
)
{
out_dims
[
out_dims
.
size
()
-
1
]
=
depth_v
;
out
->
Resize
(
out_dims
);
}
phi
::
VisitDataType
(
dtype
,
OneHotV2OpFunctor
<
Context
,
T
>
(
&
x
,
out
,
depth_v
,
dev_ctx
));
auto
*
p_in_data
=
x
.
data
<
T
>
();
auto
numel
=
x
.
numel
();
auto
*
p_out_data
=
dev_ctx
.
template
Alloc
<
float
>(
out
);
int
r
=
xpu
::
one_hot
<
T
>
(
dev_ctx
.
x_context
(),
p_in_data
,
p_out_data
,
numel
,
depth_v
,
1.0
,
0.0
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"one_hot"
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
one_hot_raw
,
XPU
,
ALL_LAYOUT
,
phi
::
OneHotRawKernel
,
int
,
int64_t
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
UNDEFINED
);
PD_REGISTER_KERNEL
(
one_hot
,
XPU
,
ALL_LAYOUT
,
phi
::
OneHotKernel
,
int
,
int64_t
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
FLOAT32
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录