Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
76f87034
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
76f87034
编写于
3月 12, 2022
作者:
C
Chen Weihang
提交者:
GitHub
3月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Phi] Move allclose op kernel into phi (#40469)
* move allclose kernel * remove allclose op kernel * fix coverage failed
上级
39de9b8a
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
276 addition
and
131 deletion
+276
-131
paddle/fluid/operators/allclose_op.cc
paddle/fluid/operators/allclose_op.cc
+1
-38
paddle/fluid/operators/allclose_op.h
paddle/fluid/operators/allclose_op.h
+0
-93
paddle/phi/api/lib/utils/tensor_utils.cc
paddle/phi/api/lib/utils/tensor_utils.cc
+7
-0
paddle/phi/kernels/allclose_kernel.h
paddle/phi/kernels/allclose_kernel.h
+31
-0
paddle/phi/kernels/cpu/allclose_kernel.cc
paddle/phi/kernels/cpu/allclose_kernel.cc
+71
-0
paddle/phi/kernels/gpu/allclose_kernel.cu
paddle/phi/kernels/gpu/allclose_kernel.cu
+89
-0
paddle/phi/ops/compat/allclose_sig.cc
paddle/phi/ops/compat/allclose_sig.cc
+49
-0
paddle/phi/tests/ops/test_op_signature.cc
paddle/phi/tests/ops/test_op_signature.cc
+28
-0
未找到文件。
paddle/fluid/operators/allclose_op.cc
浏览文件 @
76f87034
...
...
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/allclose_op.h"
#include <cmath>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/operator.h"
...
...
@@ -23,41 +23,6 @@
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
GetTensorValue
<
platform
::
CPUDeviceContext
,
T
>
{
T
operator
()(
const
platform
::
CPUDeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
tensor
)
const
{
return
*
(
tensor
.
data
<
T
>
());
}
};
template
<
typename
T
>
struct
AllcloseFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CPUDeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
other
,
const
double
rtol
,
const
double
atol
,
bool
equal_nan
,
framework
::
Tensor
*
output
)
{
auto
*
in_a
=
in
.
data
<
T
>
();
auto
*
in_b
=
other
.
data
<
T
>
();
auto
*
out_data
=
output
->
mutable_data
<
bool
>
(
ctx
.
GetPlace
());
auto
num
=
in
.
numel
();
*
out_data
=
true
;
for
(
int
i
=
0
;
i
<
num
;
i
++
)
{
const
T
a
=
in_a
[
i
],
b
=
in_b
[
i
];
bool
val
;
if
(
std
::
isnan
(
a
)
||
std
::
isnan
(
b
))
{
val
=
equal_nan
&&
std
::
isnan
(
a
)
==
std
::
isnan
(
b
);
}
else
{
T
left
=
(
a
>
b
?
a
-
b
:
b
-
a
);
T
right
=
atol
+
(
b
>
0
?
rtol
*
b
:
(
-
rtol
)
*
b
);
T
diff
=
(
left
>
right
?
left
-
right
:
right
-
left
);
val
=
a
==
b
||
left
<=
right
||
diff
<=
1e-15
;
}
*
out_data
&=
val
;
}
}
};
class
AllcloseOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
...
...
@@ -157,8 +122,6 @@ REGISTER_OPERATOR(
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
,
ops
::
AllcloseOpVarTypeInference
);
REGISTER_OP_CPU_KERNEL
(
allclose
,
ops
::
AllcloseKernel
<
CPU
,
float
>
,
ops
::
AllcloseKernel
<
CPU
,
double
>
);
/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION
(
allclose
)
...
...
paddle/fluid/operators/allclose_op.h
已删除
100644 → 0
浏览文件 @
39de9b8a
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
DeviceContext
,
typename
T
>
struct
GetTensorValue
{
T
operator
()(
const
platform
::
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
tensor
)
const
;
};
template
<
typename
DeviceContext
,
typename
T
>
struct
AllcloseFunctor
{
void
operator
()(
const
DeviceContext
&
ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
other
,
const
float
rtol
,
const
float
atol
,
bool
equal_nan
,
framework
::
Tensor
*
output
);
};
template
<
typename
DeviceContext
,
typename
T
>
class
AllcloseKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
// get attrs
bool
equal_nan
=
ctx
.
Attr
<
bool
>
(
"equal_nan"
);
// get input/output
const
auto
*
input
=
ctx
.
Input
<
Tensor
>
(
"Input"
);
const
auto
*
other
=
ctx
.
Input
<
Tensor
>
(
"Other"
);
auto
*
out
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
double
rtol_v
=
std
::
stod
(
ctx
.
Attr
<
std
::
string
>
(
"rtol"
));
double
atol_v
=
std
::
stod
(
ctx
.
Attr
<
std
::
string
>
(
"atol"
));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
GetTensorValue
<
DeviceContext
,
double
>
get_tensor_value
;
if
(
ctx
.
HasInput
(
"Rtol"
))
{
const
auto
*
rtol
=
ctx
.
Input
<
Tensor
>
(
"Rtol"
);
PADDLE_ENFORCE_EQ
(
rtol
->
numel
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Input(Rtol) size must be 1, but get %d."
,
rtol
->
numel
()));
PADDLE_ENFORCE_EQ
(
framework
::
TransToProtoVarType
(
rtol
->
dtype
()),
framework
::
proto
::
VarType
::
FP64
,
platform
::
errors
::
InvalidArgument
(
"Input(Rtol) type must be double, but get %s."
,
framework
::
DataTypeToString
(
framework
::
TransToProtoVarType
(
rtol
->
dtype
()))));
rtol_v
=
get_tensor_value
(
dev_ctx
,
*
rtol
);
}
if
(
ctx
.
HasInput
(
"Atol"
))
{
const
auto
*
atol
=
ctx
.
Input
<
Tensor
>
(
"Atol"
);
PADDLE_ENFORCE_EQ
(
atol
->
numel
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Input(Atol) size must be 1, but get %d"
,
atol
->
numel
()));
PADDLE_ENFORCE_EQ
(
framework
::
TransToProtoVarType
(
atol
->
dtype
()),
framework
::
proto
::
VarType
::
FP64
,
platform
::
errors
::
InvalidArgument
(
"Input(Atol) type must be double, but get %s"
,
framework
::
DataTypeToString
(
framework
::
TransToProtoVarType
(
atol
->
dtype
()))));
atol_v
=
get_tensor_value
(
dev_ctx
,
*
atol
);
}
AllcloseFunctor
<
DeviceContext
,
T
>
()(
dev_ctx
,
*
input
,
*
other
,
rtol_v
,
atol_v
,
equal_nan
,
out
);
}
};
}
// namespace operators
}
// namespace paddle
paddle/phi/api/lib/utils/tensor_utils.cc
浏览文件 @
76f87034
...
...
@@ -40,6 +40,13 @@ phi::Scalar MakePhiScalarFromVar(const framework::Variable& variable) {
auto
expected_place
=
phi
::
TransToPhiPlace
(
phi
::
Backend
::
CPU
);
if
(
variable
.
IsType
<
framework
::
LoDTensor
>
())
{
const
auto
&
tensor
=
variable
.
Get
<
framework
::
LoDTensor
>
();
PADDLE_ENFORCE_EQ
(
tensor
.
numel
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"The DenseTensor used to construct "
"the Scalar contains more than 1 "
"value, it contains `%d` values."
,
tensor
.
numel
()));
if
(
!
platform
::
is_same_place
(
tensor
.
place
(),
expected_place
))
{
framework
::
LoDTensor
tmp_tensor
;
framework
::
TensorCopySync
(
tensor
,
expected_place
,
&
tmp_tensor
);
...
...
paddle/phi/kernels/allclose_kernel.h
0 → 100644
浏览文件 @
76f87034
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
AllCloseKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
Scalar
&
rtol
,
const
Scalar
&
atol
,
bool
equal_nan
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/kernels/cpu/allclose_kernel.cc
0 → 100644
浏览文件 @
76f87034
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/allclose_kernel.h"
#include <cmath>
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
AllCloseKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
Scalar
&
rtol
,
const
Scalar
&
atol
,
bool
equal_nan
,
DenseTensor
*
out
)
{
PADDLE_ENFORCE_EQ
(
rtol
.
dtype
(),
DataType
::
FLOAT64
,
phi
::
errors
::
InvalidArgument
(
"Input (Rtol) type must be double, but get %s."
,
rtol
.
dtype
()));
PADDLE_ENFORCE_EQ
(
atol
.
dtype
(),
DataType
::
FLOAT64
,
phi
::
errors
::
InvalidArgument
(
"Input (Atol) type must be double, but get %s."
,
atol
.
dtype
()));
auto
*
in_a
=
x
.
data
<
T
>
();
auto
*
in_b
=
y
.
data
<
T
>
();
auto
rtol_v
=
rtol
.
to
<
double
>
();
auto
atol_v
=
atol
.
to
<
double
>
();
auto
*
out_data
=
dev_ctx
.
template
Alloc
<
bool
>(
out
);
*
out_data
=
true
;
auto
num
=
x
.
numel
();
for
(
int64_t
i
=
0
;
i
<
num
;
++
i
)
{
const
T
a
=
in_a
[
i
],
b
=
in_b
[
i
];
bool
val
;
if
(
std
::
isnan
(
a
)
||
std
::
isnan
(
b
))
{
val
=
equal_nan
&&
std
::
isnan
(
a
)
==
std
::
isnan
(
b
);
}
else
{
T
left
=
(
a
>
b
?
a
-
b
:
b
-
a
);
T
right
=
atol_v
+
(
b
>
0
?
rtol_v
*
b
:
(
-
rtol_v
)
*
b
);
T
diff
=
(
left
>
right
?
left
-
right
:
right
-
left
);
val
=
a
==
b
||
left
<=
right
||
diff
<=
1e-15
;
}
*
out_data
&=
val
;
}
}
}
// namespace phi
PD_REGISTER_KERNEL
(
allclose
,
CPU
,
ALL_LAYOUT
,
phi
::
AllCloseKernel
,
float
,
double
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
BOOL
);
}
paddle/
fluid/operators/allclose_op
.cu
→
paddle/
phi/kernels/gpu/allclose_kernel
.cu
浏览文件 @
76f87034
// Copyright (c) 20
18
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 20
22
PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...
...
@@ -12,30 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/allclose_op.h"
#include "paddle/phi/kernels/allclose_kernel.h"
namespace
paddle
{
namespace
operators
{
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
template
<
typename
T
>
struct
GetTensorValue
<
platform
::
CUDADeviceContext
,
T
>
{
T
operator
()(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
tensor
)
const
{
const
T
*
data
=
tensor
.
data
<
T
>
();
T
value
;
const
auto
gpu_place
=
dev_ctx
.
GetPlace
();
memory
::
Copy
(
platform
::
CPUPlace
(),
&
value
,
gpu_place
,
data
,
sizeof
(
T
),
dev_ctx
.
stream
());
return
value
;
}
};
namespace
phi
{
template
<
typename
T
>
__global__
void
AllcloseCUDAKernel
(
const
T
*
in_data
,
const
T
*
other_data
,
const
double
rtol
,
const
double
atol
,
bool
equal_nan
,
int
num
,
bool
*
out_data
)
{
__global__
void
AllcloseCUDAKernel
(
const
T
*
in_data
,
const
T
*
other_data
,
const
double
rtol
,
const
double
atol
,
bool
equal_nan
,
int
num
,
bool
*
out_data
)
{
unsigned
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
bool
val
;
for
(
int
i
=
idx
;
i
<
num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
...
...
@@ -52,33 +43,47 @@ __global__ void AllcloseCUDAKernel(const T* in_data, const T* other_data,
}
}
template
<
typename
T
>
struct
AllcloseFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
const
framework
::
Tensor
&
in
,
const
framework
::
Tensor
&
other
,
const
double
rtol
,
const
double
atol
,
bool
equal_nan
,
framework
::
Tensor
*
output
)
{
int
num
=
in
.
numel
();
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
other_data
=
other
.
data
<
T
>
();
bool
*
out_data
=
output
->
mutable_data
<
bool
>
(
dev_ctx
.
GetPlace
());
int
block
=
1024
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
grid
=
(
grid
>
block
)
?
block
:
grid
;
template
<
typename
T
,
typename
Context
>
void
AllCloseKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
const
Scalar
&
rtol
,
const
Scalar
&
atol
,
bool
equal_nan
,
DenseTensor
*
out
)
{
PADDLE_ENFORCE_EQ
(
rtol
.
dtype
(),
DataType
::
FLOAT64
,
phi
::
errors
::
InvalidArgument
(
"Input (Rtol) type must be double, but get %s."
,
rtol
.
dtype
()));
PADDLE_ENFORCE_EQ
(
atol
.
dtype
(),
DataType
::
FLOAT64
,
phi
::
errors
::
InvalidArgument
(
"Input (Atol) type must be double, but get %s."
,
atol
.
dtype
()));
const
T
*
in_data
=
x
.
data
<
T
>
();
const
T
*
other_data
=
y
.
data
<
T
>
();
auto
rtol_v
=
rtol
.
to
<
double
>
();
auto
atol_v
=
atol
.
to
<
double
>
();
bool
*
out_data
=
dev_ctx
.
template
Alloc
<
bool
>(
out
);
int
num
=
x
.
numel
();
int
block
=
1024
;
int
grid
=
(
block
-
1
+
num
)
/
block
;
grid
=
(
grid
>
block
)
?
block
:
grid
;
#ifdef PADDLE_WITH_HIP
hipMemset
(
out_data
,
true
,
sizeof
(
bool
));
hipMemset
(
out_data
,
true
,
sizeof
(
bool
));
#else
cudaMemset
(
out_data
,
true
,
sizeof
(
bool
));
cudaMemset
(
out_data
,
true
,
sizeof
(
bool
));
#endif
AllcloseCUDAKernel
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
in_data
,
other_data
,
rtol
,
atol
,
equal_nan
,
num
,
out_data
);
}
};
AllcloseCUDAKernel
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
in_data
,
other_data
,
rtol_v
,
atol_v
,
equal_nan
,
num
,
out_data
);
}
}
// namespace operators
}
// namespace paddle
}
// namespace phi
namespace
ops
=
paddle
::
operators
;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
REGISTER_OP_CUDA_KERNEL
(
allclose
,
ops
::
AllcloseKernel
<
CUDA
,
float
>
,
ops
::
AllcloseKernel
<
CUDA
,
double
>
);
PD_REGISTER_KERNEL
(
allclose
,
GPU
,
ALL_LAYOUT
,
phi
::
AllCloseKernel
,
float
,
double
)
{
kernel
->
OutputAt
(
0
).
SetDataType
(
phi
::
DataType
::
BOOL
);
}
paddle/phi/ops/compat/allclose_sig.cc
0 → 100644
浏览文件 @
76f87034
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
AllCloseOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
if
(
ctx
.
HasInput
(
"Rtol"
))
{
if
(
ctx
.
HasInput
(
"Atol"
))
{
return
KernelSignature
(
"allclose"
,
{
"Input"
,
"Other"
},
{
"Rtol"
,
"Atol"
,
"equal_nan"
},
{
"Out"
});
}
else
{
return
KernelSignature
(
"allclose"
,
{
"Input"
,
"Other"
},
{
"Rtol"
,
"atol"
,
"equal_nan"
},
{
"Out"
});
}
}
else
{
if
(
ctx
.
HasInput
(
"Atol"
))
{
return
KernelSignature
(
"allclose"
,
{
"Input"
,
"Other"
},
{
"rtol"
,
"Atol"
,
"equal_nan"
},
{
"Out"
});
}
else
{
return
KernelSignature
(
"allclose"
,
{
"Input"
,
"Other"
},
{
"rtol"
,
"atol"
,
"equal_nan"
},
{
"Out"
});
}
}
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
allclose
,
phi
::
AllCloseOpArgumentMapping
);
paddle/phi/tests/ops/test_op_signature.cc
浏览文件 @
76f87034
...
...
@@ -484,5 +484,33 @@ TEST(ARG_MAP, set_value) {
"set_value"
);
}
TEST
(
ARG_MAP
,
allclose
)
{
TestArgumentMappingContext
arg_case1
(
{
"Input"
,
"Other"
,
"Rtol"
},
{},
{{
"atol"
,
paddle
::
any
(
std
::
string
{
"1e-8"
})},
{
"equal_nan"
,
paddle
::
any
(
false
)}},
{
"Out"
},
{});
auto
signature1
=
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"allclose"
)(
arg_case1
);
ASSERT_EQ
(
signature1
.
name
,
"allclose"
);
auto
attr_names1
=
std
::
get
<
1
>
(
signature1
.
args
);
ASSERT_EQ
(
attr_names1
[
0
],
"Rtol"
);
TestArgumentMappingContext
arg_case2
(
{
"Input"
,
"Other"
,
"Atol"
},
{},
{{
"rtol"
,
paddle
::
any
(
std
::
string
{
"1e-5"
})},
{
"equal_nan"
,
paddle
::
any
(
false
)}},
{
"Out"
},
{});
auto
signature2
=
OpUtilsMap
::
Instance
().
GetArgumentMappingFn
(
"allclose"
)(
arg_case2
);
ASSERT_EQ
(
signature2
.
name
,
"allclose"
);
auto
attr_names2
=
std
::
get
<
1
>
(
signature2
.
args
);
ASSERT_EQ
(
attr_names2
[
1
],
"Atol"
);
}
}
// namespace tests
}
// namespace phi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录