Unverified commit 85f8fd9b
Authored on Mar 15, 2022 by Zhang Zheng; committed via GitHub on Mar 15, 2022
[Phi]Move searchsorted kernel to phi (#40520)
Parent: 1a32391c
Showing 5 changed files with 124 additions and 73 deletions (+124 −73)
paddle/fluid/operators/searchsorted_op.cc              +1   −9
paddle/phi/kernels/cpu/searchsorted_kernel.cc          +28  −0
paddle/phi/kernels/gpu/searchsorted_kernel.cu          +28  −0
paddle/phi/kernels/impl/searchsorted_kernel_impl.h     +52  −55
paddle/phi/kernels/searchsorted_kernel.h               +15  −9
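For context, searchsorted takes a sorted sequence and a tensor of query values and returns, for each value, the index at which it could be inserted while keeping the sequence sorted; the right attribute selects the upper rather than the lower bound, and out_int32 selects int32 versus int64 output. A minimal sketch of these semantics for a 1-D sequence, assuming the usual lower/upper-bound convention (illustration only, not the Paddle implementation):

#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch of searchsorted semantics for a 1-D sorted sequence.
// right == false -> first index i with sequence[i] >= v (lower bound)
// right == true  -> first index i with sequence[i] >  v (upper bound)
std::vector<int64_t> SearchSorted1D(const std::vector<float>& sequence,
                                    const std::vector<float>& values,
                                    bool right) {
  std::vector<int64_t> out;
  out.reserve(values.size());
  for (float v : values) {
    auto it = right ? std::upper_bound(sequence.begin(), sequence.end(), v)
                    : std::lower_bound(sequence.begin(), sequence.end(), v);
    out.push_back(static_cast<int64_t>(it - sequence.begin()));
  }
  return out;
}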
paddle/fluid/operators/searchsorted_op.cc

@@ -12,8 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/searchsorted_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
@@ -117,10 +116,3 @@ class SearchSortedOpMaker : public framework::OpProtoAndCheckerMaker {
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(searchsorted, ops::SearchSortedOp, ops::SearchSortedOpMaker);
-
-REGISTER_OP_CPU_KERNEL(
-    searchsorted,
-    ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, double>,
-    ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, int>,
-    ops::SearchSortedKernel<paddle::platform::CPUDeviceContext, int64_t>);
paddle/phi/kernels/cpu/searchsorted_kernel.cc  (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/searchsorted_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/searchsorted_kernel_impl.h"

PD_REGISTER_KERNEL(searchsorted,
                   CPU,
                   ALL_LAYOUT,
                   phi::SearchsortedKernel,
                   float,
                   double,
                   int,
                   int64_t) {}
paddle/phi/kernels/gpu/searchsorted_kernel.cu  (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/searchsorted_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/searchsorted_kernel_impl.h"

PD_REGISTER_KERNEL(searchsorted,
                   GPU,
                   ALL_LAYOUT,
                   phi::SearchsortedKernel,
                   float,
                   double,
                   int,
                   int64_t) {}
paddle/fluid/operators/searchsorted_op.h → paddle/phi/kernels/impl/searchsorted_kernel_impl.h  (renamed)

-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -16,16 +16,11 @@
 #include <math.h>
 
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/platform/enforce.h"
-#include "paddle/fluid/platform/for_range.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/kernels/funcs/algorithm.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
 
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
+namespace phi {
 
 template <typename T1, typename T2, typename OutType>
 class GpuAndCpuSearchSortedCompute {
@@ -65,9 +60,11 @@ class GpuAndCpuSearchSortedCompute {
   static HOSTDEVICE bool IsInf(int64_t x) { return false; }
 
   HOSTDEVICE GpuAndCpuSearchSortedCompute(const T1* sequence_data,
                                           const T2* value_data,
                                           bool right,
                                           bool is_1d_boundaries,
                                           int64_t val_size,
                                           int64_t seq_size,
                                           OutType* out_data)
       : sequence_data_(sequence_data),
         value_data_(value_data),
@@ -104,12 +101,13 @@ class GpuAndCpuSearchSortedCompute {
   OutType* out_data_;
 };
 
-template <typename DeviceContext, typename T1, typename OutType>
+template <typename Context, typename T1, typename OutType>
 class SearchSortedFunctor {
  public:
-  SearchSortedFunctor(const framework::ExecutionContext& context,
-                      const framework::Tensor* sorted_sequence,
-                      const framework::Tensor* value,
+  SearchSortedFunctor(const Context& context,
+                      const DenseTensor* sorted_sequence,
+                      const DenseTensor* value,
                       bool right,
                       OutType* out_data)
       : context_(context),
         sorted_sequence_(sorted_sequence),
@@ -121,74 +119,73 @@ class SearchSortedFunctor {
   void apply() {
     const T1* sequence_data = sorted_sequence_->data<T1>();
     const T2* value_data = value_->data<T2>();
-    const framework::DDim& seq_dims = sorted_sequence_->dims();
-    const framework::DDim& val_dims = value_->dims();
+    const phi::DDim& seq_dims = sorted_sequence_->dims();
+    const phi::DDim& val_dims = value_->dims();
 
     bool is_1d_boundaries = seq_dims.size() == 1;
     int64_t val_size = val_dims[val_dims.size() - 1];
     int64_t seq_size = seq_dims[seq_dims.size() - 1];
 
-    auto& dev_ctx = context_.template device_context<DeviceContext>();
-    platform::ForRange<DeviceContext> for_range(dev_ctx, value_->numel());
+    funcs::ForRange<Context> for_range(context_, value_->numel());
     GpuAndCpuSearchSortedCompute<T1, T2, OutType>
-        gpu_and_cpu_search_sorted_compute(sequence_data, value_data, right_,
-                                          is_1d_boundaries, val_size, seq_size,
-                                          out_data_);
+        gpu_and_cpu_search_sorted_compute(sequence_data,
+                                          value_data,
+                                          right_,
+                                          is_1d_boundaries,
+                                          val_size,
+                                          seq_size,
+                                          out_data_);
     for_range(gpu_and_cpu_search_sorted_compute);
   }
 
  private:
-  const framework::ExecutionContext& context_;
-  const framework::Tensor* sorted_sequence_;
-  const framework::Tensor* value_;
+  const Context& context_;
+  const DenseTensor* sorted_sequence_;
+  const DenseTensor* value_;
   bool right_;
   OutType* out_data_;
 };
 
 template <typename Visitor>
-static void VisitDataType(framework::proto::VarType::Type type,
+static void VisitDataType(DataType type,
                           Visitor visitor) {
-  if (type == framework::proto::VarType::FP32) {
+  if (type == DataType::FLOAT32) {
     visitor.template apply<float>();
-  } else if (type == framework::proto::VarType::FP64) {
+  } else if (type == DataType::FLOAT64) {
     visitor.template apply<double>();
-  } else if (type == framework::proto::VarType::INT32) {
+  } else if (type == DataType::INT32) {
     visitor.template apply<int>();
-  } else if (type == framework::proto::VarType::INT64) {
+  } else if (type == DataType::INT64) {
     visitor.template apply<int64_t>();
   } else {
-    PADDLE_THROW(platform::errors::InvalidArgument(
+    PADDLE_THROW(errors::InvalidArgument(
        "The recieved values data type %s can not meet input requirements. "
        "Because the given values data type of searchsorted operators must be "
        "float32, float64, int32 or int64. Please input appropriate "
        "sorted_sequence again! ",
-       framework::DataTypeToString(type)));
+       type));
   }
 }
 
-template <typename DeviceContext, typename T>
-class SearchSortedKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    auto* sorted_sequence = context.Input<Tensor>("SortedSequence");
-    auto* value = context.Input<Tensor>("Values");
-    bool out_int32 = context.Attr<bool>("out_int32");
-    bool right = context.Attr<bool>("right");
-    auto* out = context.Output<Tensor>("Out");
-
-    if (out_int32) {
-      int* out_data = out->mutable_data<int>(context.GetPlace());
-      SearchSortedFunctor<DeviceContext, T, int> functor(
-          context, sorted_sequence, value, right, out_data);
-      VisitDataType(framework::TransToProtoVarType(value->dtype()), functor);
-    } else {
-      int64_t* out_data = out->mutable_data<int64_t>(context.GetPlace());
-      SearchSortedFunctor<DeviceContext, T, int64_t> functor(
-          context, sorted_sequence, value, right, out_data);
-      VisitDataType(framework::TransToProtoVarType(value->dtype()), functor);
-    }
-  }
-};
+template <typename T, typename Context>
+void SearchsortedKernel(const Context& ctx,
+                        const DenseTensor& sorted_sequence,
+                        const DenseTensor& value,
+                        bool out_int32,
+                        bool right,
+                        DenseTensor* out) {
+  if (out_int32) {
+    ctx.template Alloc<int>(out);
+    int* out_data = out->data<int>();
+    SearchSortedFunctor<Context, T, int> functor(
+        ctx, &sorted_sequence, &value, right, out_data);
+    VisitDataType(value.dtype(), functor);
+  } else {
+    ctx.template Alloc<int64_t>(out);
+    int64_t* out_data = out->data<int64_t>();
+    SearchSortedFunctor<Context, T, int64_t> functor(
+        ctx, &sorted_sequence, &value, right, out_data);
+    VisitDataType(value.dtype(), functor);
+  }
+}
 
-}  // namespace operators
-}  // namespace paddle
+}  // namespace phi
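The main device-portability change here is that the functor no longer extracts a DeviceContext from a framework::ExecutionContext; it passes the phi Context directly to funcs::ForRange, which invokes GpuAndCpuSearchSortedCompute once per output element on whichever backend the context belongs to. A rough CPU-only sketch of what a for-range helper of this shape does (an assumption for illustration; the real phi::funcs::ForRange also covers the GPU launch path):

#include <cstdint>

// CPU-only sketch of a ForRange-style helper: invoke functor(i) for every
// index in [0, limit). In this commit the functor passed in would be
// GpuAndCpuSearchSortedCompute<T1, T2, OutType>.
template <typename Functor>
void ForRangeCpuSketch(int64_t limit, Functor functor) {
  for (int64_t i = 0; i < limit; ++i) {
    functor(i);
  }
}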
paddle/fluid/operators/searchsorted_op.cu → paddle/phi/kernels/searchsorted_kernel.h  (renamed)

-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,12 +12,18 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "paddle/fluid/operators/searchsorted_op.h"
+#pragma once
 
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(
-    searchsorted, ops::SearchSortedKernel<plat::CUDADeviceContext, float>,
-    ops::SearchSortedKernel<plat::CUDADeviceContext, double>,
-    ops::SearchSortedKernel<plat::CUDADeviceContext, int>,
-    ops::SearchSortedKernel<plat::CUDADeviceContext, int64_t>);
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SearchsortedKernel(const Context& ctx,
+                        const DenseTensor& sorted_sequence,
+                        const DenseTensor& value,
+                        bool out_int32,
+                        bool right,
+                        DenseTensor* out);
+
+}  // namespace phi
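Taken together, the new phi files follow the common phi layout: a public header declares the kernel as a free function template over (T, Context), a shared impl header supplies the templated body, and thin per-backend translation units register it with PD_REGISTER_KERNEL. A hedged sketch of that layout using a made-up kernel (my_scale and MyScaleKernel are hypothetical; only the macro and the Alloc/data calls mirror what this commit actually uses):

// Hypothetical single-file example of the same phi kernel layout.
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

// Declaration plus inline definition of a trivial element-wise kernel.
template <typename T, typename Context>
void MyScaleKernel(const Context& ctx,
                   const DenseTensor& x,
                   float factor,
                   DenseTensor* out) {
  ctx.template Alloc<T>(out);  // allocate the output, as in SearchsortedKernel
  const T* x_data = x.data<T>();
  T* out_data = out->data<T>();
  for (int64_t i = 0; i < x.numel(); ++i) {
    out_data[i] = static_cast<T>(x_data[i] * factor);
  }
}

}  // namespace phi

// Per-backend registration, mirroring the CPU registration added above.
PD_REGISTER_KERNEL(
    my_scale, CPU, ALL_LAYOUT, phi::MyScaleKernel, float, double) {}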