Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
f25dba0a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f25dba0a
编写于
3月 10, 2022
作者:
Z
Zhong Hui
提交者:
GitHub
3月 10, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[PHI] Move arg min max to PHI. (#40222)
* move arg min max to phi. * move infermeta. * fix as reviews.
上级
1128db30
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
623 addition
and
232 deletion
+623
-232
paddle/fluid/operators/arg_max_op.cc
paddle/fluid/operators/arg_max_op.cc
+10
-14
paddle/fluid/operators/arg_min_max_op_base.h
paddle/fluid/operators/arg_min_max_op_base.h
+0
-184
paddle/fluid/operators/arg_min_op.cc
paddle/fluid/operators/arg_min_op.cc
+8
-13
paddle/fluid/operators/arg_min_op.cu
paddle/fluid/operators/arg_min_op.cu
+0
-21
paddle/phi/infermeta/unary.cc
paddle/phi/infermeta/unary.cc
+77
-0
paddle/phi/infermeta/unary.h
paddle/phi/infermeta/unary.h
+8
-0
paddle/phi/kernels/arg_min_max_kernel.h
paddle/phi/kernels/arg_min_max_kernel.h
+39
-0
paddle/phi/kernels/cpu/arg_min_max_kernel.cc
paddle/phi/kernels/cpu/arg_min_max_kernel.cc
+203
-0
paddle/phi/kernels/gpu/arg_min_max_kernel.cu
paddle/phi/kernels/gpu/arg_min_max_kernel.cu
+278
-0
未找到文件。
paddle/fluid/operators/arg_max_op.cc
浏览文件 @
f25dba0a
...
@@ -15,23 +15,19 @@ limitations under the License. */
...
@@ -15,23 +15,19 @@ limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
DECLARE_INFER_SHAPE_FUNCTOR
(
arg_max
,
ArgMaxInferShapeFunctor
,
PD_INFER_META
(
phi
::
ArgMinMaxInferMeta
));
REGISTER_OPERATOR
(
REGISTER_OPERATOR
(
arg_max
,
paddle
::
operators
::
ArgMinMaxOp
,
paddle
::
operators
::
ArgMaxOpMaker
,
arg_max
,
paddle
::
operators
::
ArgMinMaxOp
,
paddle
::
operators
::
ArgMaxOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
ArgMaxInferShapeFunctor
);
REGISTER_OP_CPU_KERNEL
(
arg_max
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int32_t
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int16_t
>
,
paddle
::
operators
::
ArgMaxKernel
<
paddle
::
platform
::
CPUDeviceContext
,
uint8_t
>
);
REGISTER_OP_VERSION
(
arg_max
)
REGISTER_OP_VERSION
(
arg_max
)
.
AddCheckpoint
(
.
AddCheckpoint
(
R"ROC(
R"ROC(
...
...
paddle/fluid/operators/arg_min_max_op_base.h
浏览文件 @
f25dba0a
...
@@ -27,193 +27,9 @@ limitations under the License. */
...
@@ -27,193 +27,9 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
enum
ArgMinMaxType
{
kArgMin
,
kArgMax
};
template
<
typename
DeviceContext
,
typename
T
,
typename
Tout
,
int64_t
Rank
,
ArgMinMaxType
argMinMaxValue
>
struct
ArgMinMaxFunctor
{};
#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
template <typename DeviceContext, typename T, typename Tout, int64_t Rank> \
struct ArgMinMaxFunctor<DeviceContext, T, Tout, Rank, \
enum_argminmax_value> { \
void operator()(const DeviceContext& ctx, const framework::LoDTensor& in, \
framework::LoDTensor* out, framework::DDim x_dims, \
int64_t axis, bool keepdims) { \
auto in_eigen = framework::EigenTensor<T, Rank>::From(in, x_dims); \
if (keepdims) { \
auto out_eigen = framework::EigenTensor<Tout, Rank>::From(*out); \
out_eigen.device(*(ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} else { \
auto out_eigen = framework::EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(*(ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
} \
}
DECLARE_ARG_MIN_MAX_FUNCTOR
(
argmin
,
ArgMinMaxType
::
kArgMin
);
DECLARE_ARG_MIN_MAX_FUNCTOR
(
argmax
,
ArgMinMaxType
::
kArgMax
);
template
<
typename
DeviceContext
,
typename
T
,
ArgMinMaxType
EnumArgMinMaxValue
>
struct
VisitDataArgMinMaxFunctor
{
const
framework
::
ExecutionContext
&
ctx
;
explicit
VisitDataArgMinMaxFunctor
(
const
framework
::
ExecutionContext
&
ctx
)
:
ctx
(
ctx
)
{}
template
<
typename
Tout
>
void
apply
()
const
{
auto
&
x
=
*
(
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
));
auto
&
out
=
*
(
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
));
out
.
template
mutable_data
<
Tout
>(
ctx
.
GetPlace
());
auto
axis
=
ctx
.
Attr
<
int64_t
>
(
"axis"
);
auto
keepdims
=
ctx
.
Attr
<
bool
>
(
"keepdims"
);
const
bool
&
flatten
=
ctx
.
Attr
<
bool
>
(
"flatten"
);
// paddle do not have the scalar tensor, just return the shape [1] tensor
if
(
flatten
)
keepdims
=
true
;
// if flatten, will construct the new dims for the cacluate
framework
::
DDim
x_dims
;
if
(
flatten
)
{
x_dims
=
phi
::
make_ddim
({
x
.
numel
()});
// if flatten, the axis just as 0
axis
=
0
;
}
else
{
x_dims
=
x
.
dims
();
if
(
axis
<
0
)
axis
+=
x_dims
.
size
();
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<DeviceContext, T, Tout, rank, EnumArgMinMaxValue> \
functor##rank; \
functor##rank(dev_ctx, x, &out, x_dims, axis, keepdims)
switch
(
x_dims
.
size
())
{
case
1
:
CALL_ARG_MINMAX_FUNCTOR
(
1
);
break
;
case
2
:
CALL_ARG_MINMAX_FUNCTOR
(
2
);
break
;
case
3
:
CALL_ARG_MINMAX_FUNCTOR
(
3
);
break
;
case
4
:
CALL_ARG_MINMAX_FUNCTOR
(
4
);
break
;
case
5
:
CALL_ARG_MINMAX_FUNCTOR
(
5
);
break
;
case
6
:
CALL_ARG_MINMAX_FUNCTOR
(
6
);
break
;
default:
PADDLE_ENFORCE_LE
(
x_dims
.
size
(),
6
,
platform
::
errors
::
InvalidArgument
(
"%s operator doesn't supports tensors whose ranks are greater "
"than 6."
,
(
EnumArgMinMaxValue
==
kArgMin
?
"argmin"
:
"argmax"
)));
break
;
#undef CALL_ARG_MINMAX_FUNCTOR
}
}
};
template
<
typename
DeviceContext
,
typename
T
,
ArgMinMaxType
EnumArgMinMaxValue
>
class
ArgMinMaxKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
dtype
=
ctx
.
Attr
<
int
>
(
"dtype"
);
if
(
dtype
<
0
)
{
framework
::
VisitDataTypeTiny
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
framework
::
proto
::
VarType
::
INT64
),
VisitDataArgMinMaxFunctor
<
DeviceContext
,
T
,
EnumArgMinMaxValue
>
(
ctx
));
return
;
}
framework
::
VisitDataTypeTiny
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
dtype
),
VisitDataArgMinMaxFunctor
<
DeviceContext
,
T
,
EnumArgMinMaxValue
>
(
ctx
));
}
};
template
<
typename
DeviceContext
,
typename
T
>
using
ArgMinKernel
=
ArgMinMaxKernel
<
DeviceContext
,
T
,
ArgMinMaxType
::
kArgMin
>
;
template
<
typename
DeviceContext
,
typename
T
>
using
ArgMaxKernel
=
ArgMinMaxKernel
<
DeviceContext
,
T
,
ArgMinMaxType
::
kArgMax
>
;
class
ArgMinMaxOp
:
public
framework
::
OperatorWithKernel
{
class
ArgMinMaxOp
:
public
framework
::
OperatorWithKernel
{
public:
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"arg_min_max"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"arg_min_max"
);
const
auto
&
x_dims
=
ctx
->
GetInputDim
(
"X"
);
int64_t
axis
=
ctx
->
Attrs
().
Get
<
int64_t
>
(
"axis"
);
bool
keepdims
=
ctx
->
Attrs
().
Get
<
bool
>
(
"keepdims"
);
const
bool
&
flatten
=
ctx
->
Attrs
().
Get
<
bool
>
(
"flatten"
);
PADDLE_ENFORCE_GE
(
axis
,
-
x_dims
.
size
(),
platform
::
errors
::
InvalidArgument
(
"'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d)."
,
axis
,
-
x_dims
.
size
()));
PADDLE_ENFORCE_LT
(
axis
,
x_dims
.
size
(),
platform
::
errors
::
InvalidArgument
(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X)."
,
axis
,
x_dims
.
size
()));
const
int
&
dtype
=
ctx
->
Attrs
().
Get
<
int
>
(
"dtype"
);
PADDLE_ENFORCE_EQ
(
(
dtype
<
0
||
dtype
==
2
||
dtype
==
3
),
true
,
platform
::
errors
::
InvalidArgument
(
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"received [%s]"
,
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
),
paddle
::
framework
::
DataTypeToString
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
dtype
))));
auto
x_rank
=
x_dims
.
size
();
if
(
axis
<
0
)
axis
+=
x_rank
;
if
(
ctx
->
IsRuntime
())
{
if
(
dtype
==
framework
::
proto
::
VarType
::
INT32
)
{
int64_t
all_element_num
=
0
;
if
(
flatten
)
{
all_element_num
=
phi
::
product
(
x_dims
);
}
else
{
all_element_num
=
x_dims
[
axis
];
}
PADDLE_ENFORCE_LE
(
all_element_num
,
INT_MAX
,
platform
::
errors
::
InvalidArgument
(
"The element num of the argmin/argmax input at axis is "
"%d, is larger than int32 maximum value:%d, you must "
"set the dtype of argmin/argmax to 'int64'."
,
all_element_num
,
INT_MAX
));
}
}
std
::
vector
<
int64_t
>
vec
;
if
(
flatten
)
{
vec
.
emplace_back
(
static_cast
<
int64_t
>
(
1
));
}
else
{
for
(
int64_t
i
=
0
;
i
<
axis
;
i
++
)
vec
.
emplace_back
(
x_dims
[
i
]);
if
(
keepdims
)
{
vec
.
emplace_back
(
static_cast
<
int64_t
>
(
1
));
}
for
(
int64_t
i
=
axis
+
1
;
i
<
x_rank
;
i
++
)
vec
.
emplace_back
(
x_dims
[
i
]);
}
ctx
->
SetOutputDim
(
"Out"
,
phi
::
make_ddim
(
vec
));
}
};
};
class
BaseArgMinMaxOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
BaseArgMinMaxOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
paddle/fluid/operators/arg_min_op.cc
浏览文件 @
f25dba0a
...
@@ -12,26 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,26 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"
#include "paddle/fluid/operators/arg_min_max_op_base.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
DECLARE_INFER_SHAPE_FUNCTOR
(
arg_min
,
ArgMinInferShapeFunctor
,
PD_INFER_META
(
phi
::
ArgMinMaxInferMeta
));
REGISTER_OPERATOR
(
REGISTER_OPERATOR
(
arg_min
,
paddle
::
operators
::
ArgMinMaxOp
,
paddle
::
operators
::
ArgMinOpMaker
,
arg_min
,
paddle
::
operators
::
ArgMinMaxOp
,
paddle
::
operators
::
ArgMinOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
ArgMinInferShapeFunctor
);
REGISTER_OP_CPU_KERNEL
(
arg_min
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int32_t
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int16_t
>
,
paddle
::
operators
::
ArgMinKernel
<
paddle
::
platform
::
CPUDeviceContext
,
uint8_t
>
);
REGISTER_OP_VERSION
(
arg_min
)
REGISTER_OP_VERSION
(
arg_min
)
.
AddCheckpoint
(
.
AddCheckpoint
(
R"ROC(
R"ROC(
...
...
paddle/fluid/operators/arg_min_op.cu
已删除
100644 → 0
浏览文件 @
1128db30
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/arg_min_max_op_base.cu.h"
REGISTER_OP_CUDA_KERNEL
(
arg_min
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
float
,
cub
::
ArgMin
>
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
double
,
cub
::
ArgMin
>
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
int64_t
,
cub
::
ArgMin
>
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
int32_t
,
cub
::
ArgMin
>
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
int8_t
,
cub
::
ArgMin
>
);
paddle/phi/infermeta/unary.cc
浏览文件 @
f25dba0a
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm>
#include <algorithm>
#include <set>
#include <set>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/enforce.h"
...
@@ -1014,6 +1015,82 @@ void DiagInferMeta(const MetaTensor& x,
...
@@ -1014,6 +1015,82 @@ void DiagInferMeta(const MetaTensor& x,
}
}
}
}
void
ArgMinMaxInferMeta
(
const
MetaTensor
&
x
,
int64_t
axis
,
bool
keepdims
,
bool
flatten
,
int
dtype
,
MetaTensor
*
out
,
MetaConfig
config
)
{
const
auto
&
x_dims
=
x
.
dims
();
PADDLE_ENFORCE_GE
(
axis
,
-
x_dims
.
size
(),
phi
::
errors
::
InvalidArgument
(
"'axis'(%d) must be greater than or equal to"
" -Rank(X)(%d)."
,
axis
,
-
x_dims
.
size
()));
PADDLE_ENFORCE_LT
(
axis
,
x_dims
.
size
(),
phi
::
errors
::
InvalidArgument
(
"'axis'(%d) must be less than Rank(X)(%d) of Input(X)."
,
axis
,
x_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
(
dtype
<
0
||
dtype
==
2
||
dtype
==
3
),
true
,
phi
::
errors
::
InvalidArgument
(
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"received [%s]"
,
paddle
::
framework
::
DataTypeToString
(
paddle
::
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
paddle
::
framework
::
proto
::
VarType
::
INT64
),
paddle
::
framework
::
DataTypeToString
(
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
dtype
))));
auto
x_rank
=
x_dims
.
size
();
if
(
axis
<
0
)
axis
+=
x_rank
;
if
(
config
.
is_runtime
)
{
if
(
dtype
==
paddle
::
framework
::
proto
::
VarType
::
INT32
)
{
int64_t
all_element_num
=
0
;
if
(
flatten
)
{
all_element_num
=
phi
::
product
(
x_dims
);
}
else
{
all_element_num
=
x_dims
[
axis
];
}
PADDLE_ENFORCE_LE
(
all_element_num
,
INT_MAX
,
phi
::
errors
::
InvalidArgument
(
"The element num of the argmin/argmax input at axis is "
"%d, is larger than int32 maximum value:%d, you must "
"set the dtype of argmin/argmax to 'int64'."
,
all_element_num
,
INT_MAX
));
}
}
std
::
vector
<
int64_t
>
vec
;
if
(
flatten
)
{
vec
.
emplace_back
(
static_cast
<
int64_t
>
(
1
));
}
else
{
for
(
int64_t
i
=
0
;
i
<
axis
;
i
++
)
vec
.
emplace_back
(
x_dims
[
i
]);
if
(
keepdims
)
{
vec
.
emplace_back
(
static_cast
<
int64_t
>
(
1
));
}
for
(
int64_t
i
=
axis
+
1
;
i
<
x_rank
;
i
++
)
vec
.
emplace_back
(
x_dims
[
i
]);
}
out
->
set_dims
(
phi
::
make_ddim
(
vec
));
if
(
dtype
==
2
)
{
out
->
set_dtype
(
DataType
::
INT32
);
}
else
if
(
dtype
==
3
)
{
out
->
set_dtype
(
DataType
::
INT64
);
}
}
void
SizeInferMeta
(
const
MetaTensor
&
input
,
MetaTensor
*
out
)
{
void
SizeInferMeta
(
const
MetaTensor
&
input
,
MetaTensor
*
out
)
{
out
->
set_dtype
(
DataType
::
INT64
);
out
->
set_dtype
(
DataType
::
INT64
);
out
->
set_dims
({
1
});
out
->
set_dims
({
1
});
...
...
paddle/phi/infermeta/unary.h
浏览文件 @
f25dba0a
...
@@ -147,6 +147,14 @@ void DiagInferMeta(const MetaTensor& x,
...
@@ -147,6 +147,14 @@ void DiagInferMeta(const MetaTensor& x,
float
padding_value
,
float
padding_value
,
MetaTensor
*
out
);
MetaTensor
*
out
);
void
ArgMinMaxInferMeta
(
const
MetaTensor
&
x
,
int64_t
axis
,
bool
keepdims
,
bool
flatten
,
int
dtype
,
MetaTensor
*
out
,
MetaConfig
config
=
MetaConfig
());
void
SizeInferMeta
(
const
MetaTensor
&
input
,
MetaTensor
*
out
);
void
SizeInferMeta
(
const
MetaTensor
&
input
,
MetaTensor
*
out
);
void
DiagonalInferMeta
(
void
DiagonalInferMeta
(
...
...
paddle/
fluid/operators/arg_max_op.cu
→
paddle/
phi/kernels/arg_min_max_kernel.h
浏览文件 @
f25dba0a
/* Copyright (c) 20
18 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 20
22 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
...
@@ -12,11 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,11 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#
include "paddle/fluid/operators/arg_min_max_op_base.cu.h"
#
pragma once
REGISTER_OP_CUDA_KERNEL
(
#include "paddle/phi/core/dense_tensor.h"
arg_max
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
float
,
cub
::
ArgMax
>
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
double
,
cub
::
ArgMax
>
,
namespace
phi
{
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
int64_t
,
cub
::
ArgMax
>
,
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
int32_t
,
cub
::
ArgMax
>
,
template
<
typename
T
,
typename
Context
>
paddle
::
operators
::
ArgMinMaxOpCUDAKernel
<
int8_t
,
cub
::
ArgMax
>
);
void
ArgMinKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int64_t
axis
,
bool
keepdims
,
bool
flatten
,
int
dtype
,
DenseTensor
*
out
);
template
<
typename
T
,
typename
Context
>
void
ArgMaxKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int64_t
axis
,
bool
keepdims
,
bool
flatten
,
int
dtype
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/kernels/cpu/arg_min_max_kernel.cc
0 → 100644
浏览文件 @
f25dba0a
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/arg_min_max_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
phi
{
enum
ArgMinMaxType
{
kArgMin
,
kArgMax
};
template
<
typename
Context
,
typename
T
,
typename
Tout
,
int64_t
Rank
,
ArgMinMaxType
argMinMaxValue
>
struct
ArgMinMaxFunctor
{};
#define DECLARE_ARG_MIN_MAX_FUNCTOR(eigen_op_type, enum_argminmax_value) \
template <typename Context, typename T, typename Tout, int64_t Rank> \
struct ArgMinMaxFunctor<Context, T, Tout, Rank, enum_argminmax_value> { \
void operator()(const Context& dev_ctx, \
const DenseTensor& in, \
DenseTensor* out, \
phi::DDim x_dims, \
int64_t axis, \
bool keepdims) { \
auto in_eigen = EigenTensor<T, Rank>::From(in, x_dims); \
if (keepdims) { \
auto out_eigen = EigenTensor<Tout, Rank>::From(*out); \
out_eigen.device(*(dev_ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} else { \
auto out_eigen = EigenTensor<Tout, Rank - 1>::From(*out); \
out_eigen.device(*(dev_ctx.eigen_device())) = \
in_eigen.eigen_op_type(axis).template cast<Tout>(); \
} \
} \
}
DECLARE_ARG_MIN_MAX_FUNCTOR
(
argmin
,
ArgMinMaxType
::
kArgMin
);
DECLARE_ARG_MIN_MAX_FUNCTOR
(
argmax
,
ArgMinMaxType
::
kArgMax
);
template
<
typename
Context
,
typename
T
,
ArgMinMaxType
EnumArgMinMaxValue
>
struct
VisitDataArgMinMaxFunctor
{
const
Context
&
dev_ctx
;
const
DenseTensor
&
x
;
int64_t
axis
;
bool
keepdims
;
bool
flatten
;
DenseTensor
*
out
;
explicit
VisitDataArgMinMaxFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int64_t
axis
,
bool
keepdims
,
bool
flatten
,
DenseTensor
*
out
)
:
dev_ctx
(
dev_ctx
),
x
(
x
),
axis
(
axis
),
keepdims
(
keepdims
),
flatten
(
flatten
),
out
(
out
)
{}
template
<
typename
Tout
>
void
apply
()
const
{
dev_ctx
.
template
Alloc
<
Tout
>(
out
);
bool
new_keepdims
=
keepdims
;
if
(
flatten
)
new_keepdims
=
true
;
// if flatten, will construct the new dims for the cacluate
phi
::
DDim
x_dims
;
int
new_axis
=
axis
;
if
(
flatten
)
{
x_dims
=
phi
::
make_ddim
({
x
.
numel
()});
// if flatten, the axis just as 0
new_axis
=
0
;
}
else
{
x_dims
=
x
.
dims
();
if
(
axis
<
0
)
new_axis
=
axis
+
x_dims
.
size
();
}
#define CALL_ARG_MINMAX_FUNCTOR(rank) \
ArgMinMaxFunctor<Context, T, Tout, rank, EnumArgMinMaxValue> functor##rank; \
functor##rank(dev_ctx, x, out, x_dims, new_axis, new_keepdims)
switch
(
x_dims
.
size
())
{
case
1
:
CALL_ARG_MINMAX_FUNCTOR
(
1
);
break
;
case
2
:
CALL_ARG_MINMAX_FUNCTOR
(
2
);
break
;
case
3
:
CALL_ARG_MINMAX_FUNCTOR
(
3
);
break
;
case
4
:
CALL_ARG_MINMAX_FUNCTOR
(
4
);
break
;
case
5
:
CALL_ARG_MINMAX_FUNCTOR
(
5
);
break
;
case
6
:
CALL_ARG_MINMAX_FUNCTOR
(
6
);
break
;
default:
PADDLE_ENFORCE_LE
(
x_dims
.
size
(),
6
,
phi
::
errors
::
InvalidArgument
(
"%s operator doesn't supports tensors whose ranks are greater "
"than 6."
,
(
EnumArgMinMaxValue
==
kArgMin
?
"argmin"
:
"argmax"
)));
break
;
#undef CALL_ARG_MINMAX_FUNCTOR
}
}
};
template
<
typename
Context
,
typename
T
,
ArgMinMaxType
EnumArgMinMaxValue
>
void
ArgMinMaxKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int64_t
axis
,
bool
keepdims
,
bool
flatten
,
int
dtype
,
DenseTensor
*
out
)
{
if
(
dtype
<
0
)
{
paddle
::
framework
::
VisitDataTypeTiny
(
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
paddle
::
framework
::
proto
::
VarType
::
INT64
),
VisitDataArgMinMaxFunctor
<
Context
,
T
,
EnumArgMinMaxValue
>
(
dev_ctx
,
x
,
axis
,
keepdims
,
flatten
,
out
));
return
;
}
paddle
::
framework
::
VisitDataTypeTiny
(
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
dtype
),
VisitDataArgMinMaxFunctor
<
Context
,
T
,
EnumArgMinMaxValue
>
(
dev_ctx
,
x
,
axis
,
keepdims
,
flatten
,
out
));
}
template
<
typename
T
,
typename
Context
>
void
ArgMinKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int64_t
axis
,
bool
keepdims
,
bool
flatten
,
int
dtype
,
DenseTensor
*
out
)
{
ArgMinMaxKernel
<
Context
,
T
,
ArgMinMaxType
::
kArgMin
>
(
dev_ctx
,
x
,
axis
,
keepdims
,
flatten
,
dtype
,
out
);
}
template
<
typename
T
,
typename
Context
>
void
ArgMaxKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int64_t
axis
,
bool
keepdims
,
bool
flatten
,
int
dtype
,
DenseTensor
*
out
)
{
ArgMinMaxKernel
<
Context
,
T
,
ArgMinMaxType
::
kArgMax
>
(
dev_ctx
,
x
,
axis
,
keepdims
,
flatten
,
dtype
,
out
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
arg_min
,
CPU
,
ALL_LAYOUT
,
phi
::
ArgMinKernel
,
float
,
double
,
int32_t
,
int64_t
,
int16_t
,
uint8_t
)
{}
PD_REGISTER_KERNEL
(
arg_max
,
CPU
,
ALL_LAYOUT
,
phi
::
ArgMaxKernel
,
float
,
double
,
int32_t
,
int64_t
,
int16_t
,
uint8_t
)
{}
paddle/
fluid/operators/arg_min_max_op_base.cu.h
→
paddle/
phi/kernels/gpu/arg_min_max_kernel.cu
浏览文件 @
f25dba0a
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
Licensed under the Apache License, Version 2.0 (the "License");
#include "paddle/phi/kernels/arg_min_max_kernel.h"
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#if defined(__NVCC__) || defined(__HIPCC__)
#if defined(__NVCC__) || defined(__HIPCC__)
...
@@ -24,21 +27,14 @@ limitations under the License. */
...
@@ -24,21 +27,14 @@ limitations under the License. */
namespace
cub
=
hipcub
;
namespace
cub
=
hipcub
;
#endif
#endif
#include <limits>
#include <limits>
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include <typeinfo>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/ddim.h"
namespace
paddle
{
namespace
phi
{
namespace
operators
{
namespace
{
// NOLINT
namespace
{
// NOLINT
template
<
typename
K
,
typename
V
>
template
<
typename
K
,
typename
V
>
using
KeyValuePair
=
cub
::
KeyValuePair
<
K
,
V
>
;
using
KeyValuePair
=
cub
::
KeyValuePair
<
K
,
V
>
;
using
Tensor
=
framework
::
Tensor
;
}
// end namespace
}
// end namespace
...
@@ -62,7 +58,9 @@ template <typename T, typename IndType, class Reducer, size_t BlockDim>
...
@@ -62,7 +58,9 @@ template <typename T, typename IndType, class Reducer, size_t BlockDim>
__global__
void
ArgCUDAKernel
(
const
int64_t
height
,
// n * h
__global__
void
ArgCUDAKernel
(
const
int64_t
height
,
// n * h
const
int64_t
width
,
// c
const
int64_t
width
,
// c
const
int64_t
post_size
,
// h
const
int64_t
post_size
,
// h
const
Reducer
reducer
,
const
T
init
,
const
T
*
in
,
const
Reducer
reducer
,
const
T
init
,
const
T
*
in
,
IndType
*
out
)
{
IndType
*
out
)
{
typedef
cub
::
BlockReduce
<
KeyValuePair
<
int
,
T
>
,
BlockDim
>
BlockReduce
;
typedef
cub
::
BlockReduce
<
KeyValuePair
<
int
,
T
>
,
BlockDim
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
...
@@ -84,10 +82,13 @@ __global__ void ArgCUDAKernel(const int64_t height, // n * h
...
@@ -84,10 +82,13 @@ __global__ void ArgCUDAKernel(const int64_t height, // n * h
}
}
template
<
typename
T
,
typename
IndType
,
class
Reducer
>
template
<
typename
T
,
typename
IndType
,
class
Reducer
>
void
ComputeFullArg
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
&
input
,
void
ComputeFullArg
(
const
phi
::
GPUContext
&
dev_ctx
,
Tensor
*
indices
,
const
int64_t
pre
,
const
int64_t
post
,
const
DenseTensor
&
input
,
DenseTensor
*
indices
,
const
int64_t
pre
,
const
int64_t
post
,
const
int64_t
n
)
{
const
int64_t
n
)
{
auto
cu_stream
=
ctx
.
stream
();
auto
cu_stream
=
dev_ctx
.
stream
();
auto
ComputeBlockSize
=
[](
int64_t
col
)
{
auto
ComputeBlockSize
=
[](
int64_t
col
)
{
auto
block_size
=
8
;
auto
block_size
=
8
;
if
(
col
>
512
)
if
(
col
>
512
)
...
@@ -110,93 +111,168 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
...
@@ -110,93 +111,168 @@ void ComputeFullArg(const platform::CUDADeviceContext& ctx, const Tensor& input,
return
block_size
;
return
block_size
;
};
};
int64_t
max_grid_dimx
=
ctx
.
GetCUDAMaxGridDimSize
()[
0
];
int64_t
max_grid_dimx
=
dev_ctx
.
GetCUDAMaxGridDimSize
()[
0
];
int64_t
height
=
pre
*
post
;
int64_t
height
=
pre
*
post
;
int64_t
width
=
n
;
int64_t
width
=
n
;
int64_t
grid_size
=
height
<
max_grid_dimx
?
height
:
max_grid_dimx
;
int64_t
grid_size
=
height
<
max_grid_dimx
?
height
:
max_grid_dimx
;
const
T
*
in_data
=
input
.
data
<
T
>
();
const
T
*
in_data
=
input
.
data
<
T
>
();
IndType
*
out_data
=
indices
->
mutable_data
<
IndType
>
(
ctx
.
GetPlace
());
IndType
*
out_data
=
dev_ctx
.
template
Alloc
<
IndType
>(
indices
);
if
(
typeid
(
Reducer
)
==
typeid
(
cub
::
ArgMax
))
{
if
(
typeid
(
Reducer
)
==
typeid
(
cub
::
ArgMax
))
{
switch
(
ComputeBlockSize
(
width
))
{
switch
(
ComputeBlockSize
(
width
))
{
FIXED_BLOCK_DIM_CASE
(
FIXED_BLOCK_DIM_CASE
(
ArgCUDAKernel
<
T
,
IndType
,
Reducer
,
ArgCUDAKernel
<
T
,
IndType
,
Reducer
,
kBlockDim
><<<
grid_size
,
kBlockDim
,
0
,
cu_stream
>>>
(
kBlockDim
><<<
grid_size
,
kBlockDim
,
0
,
cu_stream
>>>
(
height
,
width
,
post
,
Reducer
(),
std
::
numeric_limits
<
T
>::
lowest
(),
height
,
in_data
,
out_data
));
width
,
post
,
Reducer
(),
std
::
numeric_limits
<
T
>::
lowest
(),
in_data
,
out_data
));
}
}
}
else
{
}
else
{
switch
(
ComputeBlockSize
(
width
))
{
switch
(
ComputeBlockSize
(
width
))
{
FIXED_BLOCK_DIM_CASE
(
FIXED_BLOCK_DIM_CASE
(
ArgCUDAKernel
<
T
,
IndType
,
Reducer
,
ArgCUDAKernel
<
T
,
IndType
,
Reducer
,
kBlockDim
><<<
grid_size
,
kBlockDim
,
0
,
cu_stream
>>>
(
kBlockDim
><<<
grid_size
,
kBlockDim
,
0
,
cu_stream
>>>
(
height
,
width
,
post
,
Reducer
(),
std
::
numeric_limits
<
T
>::
max
(),
height
,
in_data
,
out_data
));
width
,
post
,
Reducer
(),
std
::
numeric_limits
<
T
>::
max
(),
in_data
,
out_data
));
}
}
}
}
}
}
template
<
typename
T
,
class
Reducer
>
template
<
typename
Context
,
typename
T
,
class
Reducer
>
struct
VisitDataCudaArgMinMaxFunctor
{
struct
VisitDataCudaArgMinMaxFunctor
{
const
framework
::
ExecutionContext
&
ctx
;
const
Context
&
dev_ctx
;
const
DenseTensor
&
x
;
int64_t
axis
;
bool
keepdims
;
bool
flatten
;
DenseTensor
*
out
;
explicit
VisitDataCudaArgMinMaxFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int64_t
axis
,
bool
keepdims
,
bool
flatten
,
DenseTensor
*
out
)
:
dev_ctx
(
dev_ctx
),
x
(
x
),
axis
(
axis
),
keepdims
(
keepdims
),
flatten
(
flatten
),
out
(
out
)
{}
explicit
VisitDataCudaArgMinMaxFunctor
(
const
framework
::
ExecutionContext
&
ctx
)
:
ctx
(
ctx
)
{}
template
<
typename
IndType
>
template
<
typename
IndType
>
void
apply
()
const
{
void
apply
()
const
{
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
phi
::
DDim
x_dims
;
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
int
new_axis
=
axis
;
int
axis
=
ctx
.
Attr
<
int64_t
>
(
"axis"
);
const
bool
&
flatten
=
ctx
.
Attr
<
bool
>
(
"flatten"
);
framework
::
DDim
input_dims
;
if
(
flatten
)
{
if
(
flatten
)
{
input_dims
=
phi
::
make_ddim
({
input
->
numel
()});
x_dims
=
phi
::
make_ddim
({
x
.
numel
()});
// if flatten, the axis just as 0
// if flatten, the axis just as 0
axis
=
0
;
new_axis
=
0
;
}
else
{
}
else
{
input_dims
=
input
->
dims
();
x_dims
=
x
.
dims
();
if
(
axis
<
0
)
axis
+=
input
->
dims
().
size
();
if
(
axis
<
0
)
new_axis
=
axis
+
x
.
dims
().
size
();
}
}
int64_t
numel
=
input
->
numel
();
int64_t
numel
=
x
.
numel
();
int64_t
groups
=
numel
/
input_dims
[
axis
];
int64_t
groups
=
numel
/
x_dims
[
new_axis
];
int64_t
pre
=
1
;
int64_t
pre
=
1
;
int64_t
post
=
1
;
int64_t
post
=
1
;
int64_t
n
=
input_dims
[
axis
];
int64_t
n
=
x_dims
[
new_axis
];
for
(
int
i
=
0
;
i
<
axis
;
i
++
)
{
for
(
int
i
=
0
;
i
<
new_axis
;
i
++
)
{
pre
*=
input_dims
[
i
];
pre
*=
x_dims
[
i
];
}
}
for
(
int
i
=
axis
+
1
;
i
<
input_dims
.
size
();
i
++
)
{
for
(
int
i
=
new_axis
+
1
;
i
<
x_dims
.
size
();
i
++
)
{
post
*=
input_dims
[
i
];
post
*=
x_dims
[
i
];
}
}
const
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
ComputeFullArg
<
T
,
IndType
,
Reducer
>
(
dev_ctx
,
x
,
out
,
pre
,
post
,
n
);
ComputeFullArg
<
T
,
IndType
,
Reducer
>
(
dev_ctx
,
*
input
,
output
,
pre
,
post
,
n
);
}
}
};
};
template
<
typename
T
,
class
Reducer
>
class
ArgMinMaxOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
template
<
typename
Context
,
typename
T
,
class
Reducer
>
public:
void
ArgMinMaxOpCUDAKernel
(
const
Context
&
dev_ctx
,
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
DenseTensor
&
x
,
auto
&
dtype
=
ctx
.
Attr
<
int
>
(
"dtype"
);
int64_t
axis
,
bool
keepdims
,
bool
flatten
,
int
dtype
,
DenseTensor
*
out
)
{
if
(
dtype
<
0
)
{
if
(
dtype
<
0
)
{
framework
::
VisitDataTypeTiny
(
paddle
::
framework
::
VisitDataTypeTiny
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
framework
::
proto
::
VarType
::
INT64
),
paddle
::
framework
::
proto
::
VarType
::
INT64
),
VisitDataCudaArgMinMaxFunctor
<
T
,
Reducer
>
(
ctx
));
VisitDataCudaArgMinMaxFunctor
<
Context
,
T
,
Reducer
>
(
dev_ctx
,
x
,
axis
,
keepdims
,
flatten
,
out
));
return
;
return
;
}
}
framework
::
VisitDataTypeTiny
(
paddle
::
framework
::
VisitDataTypeTiny
(
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
dtype
),
static_cast
<
paddle
::
framework
::
proto
::
VarType
::
Type
>
(
dtype
),
VisitDataCudaArgMinMaxFunctor
<
T
,
Reducer
>
(
ctx
));
VisitDataCudaArgMinMaxFunctor
<
Context
,
T
,
Reducer
>
(
}
dev_ctx
,
x
,
axis
,
keepdims
,
flatten
,
out
));
};
}
template
<
typename
T
,
typename
Context
>
void
ArgMinKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int64_t
axis
,
bool
keepdims
,
bool
flatten
,
int
dtype
,
DenseTensor
*
out
)
{
ArgMinMaxOpCUDAKernel
<
Context
,
T
,
cub
::
ArgMin
>
(
dev_ctx
,
x
,
axis
,
keepdims
,
flatten
,
dtype
,
out
);
}
template
<
typename
T
,
typename
Context
>
void
ArgMaxKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int64_t
axis
,
bool
keepdims
,
bool
flatten
,
int
dtype
,
DenseTensor
*
out
)
{
ArgMinMaxOpCUDAKernel
<
Context
,
T
,
cub
::
ArgMax
>
(
dev_ctx
,
x
,
axis
,
keepdims
,
flatten
,
dtype
,
out
);
}
#endif
#endif
}
// namespace operators
}
// namespace phi
}
// namespace paddle
PD_REGISTER_KERNEL
(
arg_min
,
GPU
,
ALL_LAYOUT
,
phi
::
ArgMinKernel
,
float
,
double
,
int32_t
,
int64_t
,
int16_t
,
uint8_t
)
{}
PD_REGISTER_KERNEL
(
arg_max
,
GPU
,
ALL_LAYOUT
,
phi
::
ArgMaxKernel
,
float
,
double
,
int32_t
,
int64_t
,
int16_t
,
uint8_t
)
{}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录