Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
76f87034
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
76f87034
编写于
3月 12, 2022
作者:
C
Chen Weihang
提交者:
GitHub
3月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Phi] Move allclose op kernel into phi (#40469)
* move allclose kernel * remove allclose op kernel * fix coverage failed
上级
39de9b8a
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
276 addition
and
131 deletion
+276
-131
paddle/fluid/operators/allclose_op.cc
paddle/fluid/operators/allclose_op.cc
+1
-38
paddle/fluid/operators/allclose_op.h
paddle/fluid/operators/allclose_op.h
+0
-93
paddle/phi/api/lib/utils/tensor_utils.cc
paddle/phi/api/lib/utils/tensor_utils.cc
+7
-0
paddle/phi/kernels/allclose_kernel.h
paddle/phi/kernels/allclose_kernel.h
+31
-0
paddle/phi/kernels/cpu/allclose_kernel.cc
paddle/phi/kernels/cpu/allclose_kernel.cc
+71
-0
paddle/phi/kernels/gpu/allclose_kernel.cu
paddle/phi/kernels/gpu/allclose_kernel.cu
+89
-0
paddle/phi/ops/compat/allclose_sig.cc
paddle/phi/ops/compat/allclose_sig.cc
+49
-0
paddle/phi/tests/ops/test_op_signature.cc
paddle/phi/tests/ops/test_op_signature.cc
+28
-0
未找到文件。
paddle/fluid/operators/allclose_op.cc
浏览文件 @
76f87034
...
@@ -12,9 +12,9 @@
...
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/operators/allclose_op.h"
#include <cmath>
#include <cmath>
#include <string>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
...
@@ -23,41 +23,6 @@
...
@@ -23,41 +23,6 @@
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
>
struct
GetTensorValue
<
platform
::
CPUDeviceContext
,
T
>
{
T
operator
()(
const
platform
::
CPUDeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
tensor
)
const
{
return
*
(
tensor
.
data
<
T
>
());
}
};
template
<
typename
T
>
struct
AllcloseFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
other
,
const
double
rtol
,
const
double
atol
,
bool
equal_nan
,
framework
::
Tensor
*
output
)
{
auto
*
in_a
=
in
.
data
<
T
>
();
auto
*
in_b
=
other
.
data
<
T
>
();
auto
*
out_data
=
output
->
mutable_data
<
bool
>
(
ctx
.
GetPlace
());
auto
num
=
in
.
numel
();
*
out_data
=
true
;
for
(
int
i
=
0
;
i
<
num
;
i
++
)
{
const
T
a
=
in_a
[
i
],
b
=
in_b
[
i
];
bool
val
;
if
(
std
::
isnan
(
a
)
||
std
::
isnan
(
b
))
{
val
=
equal_nan
&&
std
::
isnan
(
a
)
==
std
::
isnan
(
b
);
}
else
{
T
left
=
(
a
>
b
?
a
-
b
:
b
-
a
);
T
right
=
atol
+
(
b
>
0
?
rtol
*
b
:
(
-
rtol
)
*
b
);
T
diff
=
(
left
>
right
?
left
-
right
:
right
-
left
);
val
=
a
==
b
||
left
<=
right
||
diff
<=
1e-15
;
}
*
out_data
&=
val
;
}
}
};
class
AllcloseOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
class
AllcloseOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
public:
void
Make
()
override
{
void
Make
()
override
{
...
@@ -157,8 +122,6 @@ REGISTER_OPERATOR(
...
@@ -157,8 +122,6 @@ REGISTER_OPERATOR(
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
ops
::
AllcloseOpVarTypeInference
);
ops
::
AllcloseOpVarTypeInference
);
REGISTER_OP_CPU_KERNEL
(
allclose
,
ops
::
AllcloseKernel
<
CPU
,
float
>
,
ops
::
AllcloseKernel
<
CPU
,
double
>
);
/* ========================== register checkpoint ===========================*/
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION
(
allclose
)
REGISTER_OP_VERSION
(
allclose
)
...
...
paddle/fluid/operators/allclose_op.h
已删除
100644 → 0
浏览文件 @
39de9b8a
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
DeviceContext
,
typename
T
>
struct
GetTensorValue
{
T
operator
()(
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
tensor
)
const
;
};
template
<
typename
DeviceContext
,
typename
T
>
struct
AllcloseFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
other
,
const
float
rtol
,
const
float
atol
,
bool
equal_nan
,
framework
::
Tensor
*
output
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
AllcloseKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
// get attrs
bool
equal_nan
=
ctx
.
Attr
<
bool
>
(
"equal_nan"
);
// get input/output
const
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
const
auto
*
other
=
ctx
.
Input
<
Tensor
>
(
"Other"
);
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
double
rtol_v
=
std
::
stod
(
ctx
.
Attr
<
std
::
string
>
(
"rtol"
));
double
atol_v
=
std
::
stod
(
ctx
.
Attr
<
std
::
string
>
(
"atol"
));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
GetTensorValue
<
DeviceContext
,
double
>
get_tensor_value
;
if
(
ctx
.
HasInput
(
"Rtol"
))
{
const
auto
*
rtol
=
ctx
.
Input
<
Tensor
>
(
"Rtol"
);
PADDLE_ENFORCE_EQ
(
rtol
->
numel
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Input(Rtol) size must be 1, but get %d."
,
rtol
->
numel
()));
PADDLE_ENFORCE_EQ
(
framework
::
TransToProtoVarType
(
rtol
->
dtype
()),
framework
::
proto
::
VarType
::
FP64
,
platform
::
errors
::
InvalidArgument
(
"Input(Rtol) type must be double, but get %s."
,
framework
::
DataTypeToString
(
framework
::
TransToProtoVarType
(
rtol
->
dtype
()))));
rtol_v
=
get_tensor_value
(
dev_ctx
,
*
rtol
);
}
if
(
ctx
.
HasInput
(
"Atol"
))
{
const
auto
*
atol
=
ctx
.
Input
<
Tensor
>
(
"Atol"
);
PADDLE_ENFORCE_EQ
(
atol
->
numel
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Input(Atol) size must be 1, but get %d"
,
atol
->
numel
()));
PADDLE_ENFORCE_EQ
(
framework
::
TransToProtoVarType
(
atol
->
dtype
()),
framework
::
proto
::
VarType
::
FP64
,
platform
::
errors
::
InvalidArgument
(
"Input(Atol) type must be double, but get %s"
,
framework
::
DataTypeToString
(
framework
::
TransToProtoVarType
(
atol
->
dtype
()))));
atol_v
=
get_tensor_value
(
dev_ctx
,
*
atol
);
}
AllcloseFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
input
,
*
other
,
rtol_v
,
atol_v
,
equal_nan
,
out
);
}
};
}
// namespace operators
}
// namespace paddle
paddle/phi/api/lib/utils/tensor_utils.cc
浏览文件 @
76f87034
...
@@ -40,6 +40,13 @@ phi::Scalar MakePhiScalarFromVar(const framework::Variable& variable) {
...
@@ -40,6 +40,13 @@ phi::Scalar MakePhiScalarFromVar(const framework::Variable& variable) {
auto
expected_place
=
phi
::
TransToPhiPlace
(
phi
::
Backend
::
CPU
);
auto
expected_place
=
phi
::
TransToPhiPlace
(
phi
::
Backend
::
CPU
);
if
(
variable
.
IsType
<
framework
::
LoDTensor
>
())
{
if
(
variable
.
IsType
<
framework
::
LoDTensor
>
())
{
const
auto
&
tensor
=
variable
.
Get
<
framework
::
LoDTensor
>
();
const
auto
&
tensor
=
variable
.
Get
<
framework
::
LoDTensor
>
();
PADDLE_ENFORCE_EQ
(
tensor
.
numel
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"The DenseTensor used to construct "
"the Scalar contains more than 1 "
"value, it contains `%d` values."
,
tensor
.
numel
()));
if
(
!
platform
::
is_same_place
(
tensor
.
place
(),
expected_place
))
{
if
(
!
platform
::
is_same_place
(
tensor
.
place
(),
expected_place
))
{
framework
::
LoDTensor
tmp_tensor
;
framework
::
LoDTensor
tmp_tensor
;
framework
::
TensorCopySync
(
tensor
,
expected_place
,
&
tmp_tensor
);
framework
::
TensorCopySync
(
tensor
,
expected_place
,
&
tmp_tensor
);
...
...
paddle/phi/kernels/allclose_kernel.h
0 → 100644
浏览文件 @
76f87034
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
AllCloseKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
Scalar
&
rtol
,
const
Scalar
&
atol
,
bool
equal_nan
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/kernels/cpu/allclose_kernel.cc
0 → 100644
浏览文件 @
76f87034
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/allclose_kernel.h"
#include <cmath>
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
AllCloseKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
Scalar
&
rtol
,
const
Scalar
&
atol
,
bool
equal_nan
,
DenseTensor
*
out
)
{
PADDLE_ENFORCE_EQ
(
rtol
.
dtype
(),
DataType
::
FLOAT64
,
phi
::
errors
::
InvalidArgument
(
"Input (Rtol) type must be double, but get %s."
,
rtol
.
dtype
()));
PADDLE_ENFORCE_EQ
(
atol
.
dtype
(),
DataType
::
FLOAT64
,
phi
::
errors
::
InvalidArgument
(
"Input (Atol) type must be double, but get %s."
,
atol
.
dtype
()));
auto
*
in_a
=
x
.
data
<
T
>
();
auto
*
in_b
=
y
.
data
<
T
>
();
auto
rtol_v
=
rtol
.
to
<
double
>
();
auto
atol_v
=
atol
.
to
<
double
>
();
auto
*
out_data
=
dev_ctx
.
template
Alloc
<
bool
>(
out
);
*
out_data
=
true
;
auto
num
=
x
.
numel
();
for
(
int64_t
i
=
0
;
i
<
num
;
++
i
)
{
const
T
a
=
in_a
[
i
],
b
=
in_b
[
i
];
bool
val
;
if
(
std
::
isnan
(
a
)
||
std
::
isnan
(
b
))
{
val
=
equal_nan
&&
std
::
isnan
(
a
)
==
std
::
isnan
(
b
);
}
else
{
T
left
=
(
a
>
b
?
a
-
b
:
b
-
a
);
T
right
=
atol_v
+
(
b
>
0
?
rtol_v
*
b
:
(
-
rtol_v
)
*
b
);
T
diff
=
(
left
>
right
?
left
-
right
:
right
-
left
);
val
=
a
==
b
||
left
<=
right
||
diff
<=
1e-15
;
}
*
out_data
&=
val
;
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
allclose
,
CPU
,
ALL_LAYOUT
,
phi
::
AllCloseKernel
,
float
,
double
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
BOOL
);
}
paddle/
fluid/operators/allclose_op
.cu
→
paddle/
phi/kernels/gpu/allclose_kernel
.cu
浏览文件 @
76f87034
// Copyright (c) 20
18
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 20
22
PaddlePaddle Authors. All Rights Reserved.
//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// you may not use this file except in compliance with the License.
...
@@ -12,30 +12,21 @@
...
@@ -12,30 +12,21 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/allclose_kernel.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/allclose_op.h"
namespace
paddle
{
#include "paddle/phi/core/enforce.h"
namespace
operators
{
#include "paddle/phi/core/kernel_registry.h"
template
<
typename
T
>
namespace
phi
{
struct
GetTensorValue
<
platform
::
CUDADeviceContext
,
T
>
{
T
operator
()(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
tensor
)
const
{
const
T
*
data
=
tensor
.
data
<
T
>
();
T
value
;
const
auto
gpu_place
=
dev_ctx
.
GetPlace
();
memory
::
Copy
(
platform
::
CPUPlace
(),
&
value
,
gpu_place
,
data
,
sizeof
(
T
),
dev_ctx
.
stream
());
return
value
;
}
};
template
<
typename
T
>
template
<
typename
T
>
__global__
void
AllcloseCUDAKernel
(
const
T
*
in_data
,
const
T
*
other_data
,
__global__
void
AllcloseCUDAKernel
(
const
T
*
in_data
,
const
double
rtol
,
const
double
atol
,
const
T
*
other_data
,
bool
equal_nan
,
int
num
,
bool
*
out_data
)
{
const
double
rtol
,
const
double
atol
,
bool
equal_nan
,
int
num
,
bool
*
out_data
)
{
unsigned
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
unsigned
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
bool
val
;
bool
val
;
for
(
int
i
=
idx
;
i
<
num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
for
(
int
i
=
idx
;
i
<
num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
...
@@ -52,16 +43,32 @@ __global__ void AllcloseCUDAKernel(const T* in_data, const T* other_data,
...
@@ -52,16 +43,32 @@ __global__ void AllcloseCUDAKernel(const T* in_data, const T* other_data,
}
}
}
}
template
<
typename
T
>
template
<
typename
T
,
typename
Context
>
struct
AllcloseFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
AllCloseKernel
(
const
Context
&
dev_ctx
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
const
DenseTensor
&
x
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
other
,
const
DenseTensor
&
y
,
const
double
rtol
,
const
double
atol
,
bool
equal_nan
,
const
Scalar
&
rtol
,
framework
::
Tensor
*
output
)
{
const
Scalar
&
atol
,
int
num
=
in
.
numel
();
bool
equal_nan
,
const
T
*
in_data
=
in
.
data
<
T
>
();
DenseTensor
*
out
)
{
const
T
*
other_data
=
other
.
data
<
T
>
();
PADDLE_ENFORCE_EQ
(
bool
*
out_data
=
output
->
mutable_data
<
bool
>
(
dev_ctx
.
GetPlace
());
rtol
.
dtype
(),
DataType
::
FLOAT64
,
phi
::
errors
::
InvalidArgument
(
"Input (Rtol) type must be double, but get %s."
,
rtol
.
dtype
()));
PADDLE_ENFORCE_EQ
(
atol
.
dtype
(),
DataType
::
FLOAT64
,
phi
::
errors
::
InvalidArgument
(
"Input (Atol) type must be double, but get %s."
,
atol
.
dtype
()));
const
T
*
in_data
=
x
.
data
<
T
>
();
const
T
*
other_data
=
y
.
data
<
T
>
();
auto
rtol_v
=
rtol
.
to
<
double
>
();
auto
atol_v
=
atol
.
to
<
double
>
();
bool
*
out_data
=
dev_ctx
.
template
Alloc
<
bool
>(
out
);
int
num
=
x
.
numel
();
int
block
=
1024
;
int
block
=
1024
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
grid
=
(
grid
>
block
)
?
block
:
grid
;
grid
=
(
grid
>
block
)
?
block
:
grid
;
...
@@ -71,14 +78,12 @@ struct AllcloseFunctor<platform::CUDADeviceContext, T> {
...
@@ -71,14 +78,12 @@ struct AllcloseFunctor<platform::CUDADeviceContext, T> {
cudaMemset
(
out_data
,
true
,
sizeof
(
bool
));
cudaMemset
(
out_data
,
true
,
sizeof
(
bool
));
#endif
#endif
AllcloseCUDAKernel
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
AllcloseCUDAKernel
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
in_data
,
other_data
,
rtol
,
atol
,
equal_nan
,
num
,
out_data
);
in_data
,
other_data
,
rtol_v
,
atol_v
,
equal_nan
,
num
,
out_data
);
}
}
};
}
// namespace operators
}
// namespace phi
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
PD_REGISTER_KERNEL
(
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
allclose
,
GPU
,
ALL_LAYOUT
,
phi
::
AllCloseKernel
,
float
,
double
)
{
REGISTER_OP_CUDA_KERNEL
(
allclose
,
ops
::
AllcloseKernel
<
CUDA
,
float
>
,
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
BOOL
);
ops
::
AllcloseKernel
<
CUDA
,
double
>
);
}
paddle/phi/ops/compat/allclose_sig.cc
0 → 100644
浏览文件 @
76f87034
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
AllCloseOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
if
(
ctx
.
HasInput
(
"Rtol"
))
{
if
(
ctx
.
HasInput
(
"Atol"
))
{
return
KernelSignature
(
"allclose"
,
{
"Input"
,
"Other"
},
{
"Rtol"
,
"Atol"
,
"equal_nan"
},
{
"Out"
});
}
else
{
return
KernelSignature
(
"allclose"
,
{
"Input"
,
"Other"
},
{
"Rtol"
,
"atol"
,
"equal_nan"
},
{
"Out"
});
}
}
else
{
if
(
ctx
.
HasInput
(
"Atol"
))
{
return
KernelSignature
(
"allclose"
,
{
"Input"
,
"Other"
},
{
"rtol"
,
"Atol"
,
"equal_nan"
},
{
"Out"
});
}
else
{
return
KernelSignature
(
"allclose"
,
{
"Input"
,
"Other"
},
{
"rtol"
,
"atol"
,
"equal_nan"
},
{
"Out"
});
}
}
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
allclose
,
phi
::
AllCloseOpArgumentMapping
);
paddle/phi/tests/ops/test_op_signature.cc
浏览文件 @
76f87034
...
@@ -484,5 +484,33 @@ TEST(ARG_MAP, set_value) {
...
@@ -484,5 +484,33 @@ TEST(ARG_MAP, set_value) {
"set_value"
);
"set_value"
);
}
}
TEST
(
ARG_MAP
,
allclose
)
{
TestArgumentMappingContext
arg_case1
(
{
"Input"
,
"Other"
,
"Rtol"
},
{},
{{
"atol"
,
paddle
::
any
(
std
::
string
{
"1e-8"
})},
{
"equal_nan"
,
paddle
::
any
(
false
)}},
{
"Out"
},
{});
auto
signature1
=
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"allclose"
)(
arg_case1
);
ASSERT_EQ
(
signature1
.
name
,
"allclose"
);
auto
attr_names1
=
std
::
get
<
1
>
(
signature1
.
args
);
ASSERT_EQ
(
attr_names1
[
0
],
"Rtol"
);
TestArgumentMappingContext
arg_case2
(
{
"Input"
,
"Other"
,
"Atol"
},
{},
{{
"rtol"
,
paddle
::
any
(
std
::
string
{
"1e-5"
})},
{
"equal_nan"
,
paddle
::
any
(
false
)}},
{
"Out"
},
{});
auto
signature2
=
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"allclose"
)(
arg_case2
);
ASSERT_EQ
(
signature2
.
name
,
"allclose"
);
auto
attr_names2
=
std
::
get
<
1
>
(
signature2
.
args
);
ASSERT_EQ
(
attr_names2
[
1
],
"Atol"
);
}
}
// namespace tests
}
// namespace tests
}
// namespace phi
}
// namespace phi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录