Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
def81b4f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
def81b4f
编写于
1月 24, 2022
作者:
Z
Zhang Ting
提交者:
GitHub
1月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
unify compare functor (#39024)
上级
46823104
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
59 addition
and
60 deletion
+59
-60
paddle/fluid/operators/controlflow/compare_op.h
paddle/fluid/operators/controlflow/compare_op.h
+25
-34
paddle/fluid/operators/matrix_rank_op.cc
paddle/fluid/operators/matrix_rank_op.cc
+10
-8
paddle/fluid/operators/matrix_rank_op.cu
paddle/fluid/operators/matrix_rank_op.cu
+4
-4
paddle/fluid/operators/matrix_rank_op.h
paddle/fluid/operators/matrix_rank_op.h
+1
-10
paddle/fluid/operators/viterbi_decode_op.cu
paddle/fluid/operators/viterbi_decode_op.cu
+3
-2
paddle/fluid/operators/viterbi_decode_op.h
paddle/fluid/operators/viterbi_decode_op.h
+3
-2
python/paddle/fluid/tests/unittests/test_compare_op.py
python/paddle/fluid/tests/unittests/test_compare_op.py
+13
-0
未找到文件。
paddle/fluid/operators/controlflow/compare_op.h
浏览文件 @
def81b4f
...
...
@@ -22,49 +22,40 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
LessThanFunctor
{
using
ELEM_TYPE
=
T
;
HOSTDEVICE
bool
operator
()(
const
T
a
,
const
T
b
)
const
{
return
a
<
b
;
}
};
template
<
typename
T
>
struct
LessEqualFunctor
{
using
ELEM_TYPE
=
T
;
HOSTDEVICE
bool
operator
()(
const
T
a
,
const
T
b
)
const
{
return
a
<=
b
;
}
};
template
<
typename
T
>
struct
GreaterThanFunctor
{
using
ELEM_TYPE
=
T
;
HOSTDEVICE
bool
operator
()(
const
T
a
,
const
T
b
)
const
{
return
a
>
b
;
}
};
template
<
typename
T
>
struct
GreaterEqualFunctor
{
using
ELEM_TYPE
=
T
;
HOSTDEVICE
bool
operator
()(
const
T
a
,
const
T
b
)
const
{
return
a
>=
b
;
}
};
template
<
typename
T
>
#define COMPARE_FUNCTOR(func_name, op) \
template <typename InT, typename OutT = bool> \
struct func_name { \
using ELEM_TYPE = InT; \
HOSTDEVICE OutT operator()(const InT a, const InT b) const { \
return static_cast<OutT>(a op b); \
} \
};
COMPARE_FUNCTOR
(
LessThanFunctor
,
<
)
COMPARE_FUNCTOR
(
LessEqualFunctor
,
<=
)
COMPARE_FUNCTOR
(
GreaterThanFunctor
,
>
)
COMPARE_FUNCTOR
(
GreaterEqualFunctor
,
>=
)
#undef COMPARE_FUNCTOR
template
<
typename
InT
,
typename
OutT
=
bool
>
struct
EqualFunctor
{
using
ELEM_TYPE
=
T
;
HOSTDEVICE
bool
operator
()(
const
T
a
,
const
T
b
)
const
{
if
(
std
::
is_floating_point
<
T
>::
value
)
{
using
ELEM_TYPE
=
In
T
;
HOSTDEVICE
OutT
operator
()(
const
InT
a
,
const
In
T
b
)
const
{
if
(
std
::
is_floating_point
<
In
T
>::
value
)
{
// This branch will be optimized while compiling if T is integer. It is
// safe to cast a and b to double.
return
fabs
(
static_cast
<
double
>
(
a
-
b
))
<
1e-8
;
return
static_cast
<
OutT
>
(
fabs
(
static_cast
<
double
>
(
a
-
b
))
<
1e-8
)
;
}
else
{
return
(
a
==
b
);
return
static_cast
<
OutT
>
(
a
==
b
);
}
}
};
template
<
typename
T
>
template
<
typename
InT
,
typename
OutT
=
bool
>
struct
NotEqualFunctor
{
using
ELEM_TYPE
=
T
;
HOSTDEVICE
bool
operator
()(
const
T
a
,
const
T
b
)
const
{
return
!
EqualFunctor
<
T
>
()(
a
,
b
);
using
ELEM_TYPE
=
In
T
;
HOSTDEVICE
bool
operator
()(
const
InT
a
,
const
In
T
b
)
const
{
return
!
EqualFunctor
<
InT
,
Out
T
>
()(
a
,
b
);
}
};
...
...
paddle/fluid/operators/matrix_rank_op.cc
浏览文件 @
def81b4f
...
...
@@ -219,18 +219,20 @@ class MatrixRankCPUKernel : public framework::OpKernel<T> {
tol_tensor
.
Resize
(
detail
::
NewAxisDim
(
tol_tensor
.
dims
(),
1
));
Tensor
compare_result
;
compare_result
.
mutable_data
<
int
>
(
detail
::
NewAxisDim
(
dim_out
,
k
),
compare_result
.
mutable_data
<
int
64_t
>
(
detail
::
NewAxisDim
(
dim_out
,
k
),
context
.
GetPlace
());
int
axis
=
-
1
;
if
(
eigenvalue_tensor
.
dims
().
size
()
>=
tol_tensor
.
dims
().
size
())
{
ElementwiseComputeEx
<
GreaterThanFunctor
<
T
>
,
platform
::
CPUDeviceContext
,
T
,
int
>
(
context
,
&
eigenvalue_tensor
,
&
tol_tensor
,
axis
,
GreaterThanFunctor
<
T
>
(),
&
compare_result
);
ElementwiseComputeEx
<
GreaterThanFunctor
<
T
,
int64_t
>
,
platform
::
CPUDeviceContext
,
T
,
int
>
(
context
,
&
eigenvalue_tensor
,
&
tol_tensor
,
axis
,
GreaterThanFunctor
<
T
,
int64_t
>
(),
&
compare_result
);
}
else
{
ElementwiseComputeEx
<
LessThanFunctor
<
T
>
,
platform
::
CPUDeviceContext
,
T
,
int
>
(
context
,
&
eigenvalue_tensor
,
&
tol_tensor
,
axis
,
LessThanFunctor
<
T
>
(),
&
compare_result
);
ElementwiseComputeEx
<
LessThanFunctor
<
T
,
int64_t
>
,
platform
::
CPUDeviceContext
,
T
,
int
>
(
context
,
&
eigenvalue_tensor
,
&
tol_tensor
,
axis
,
LessThanFunctor
<
T
,
int64_t
>
(),
&
compare_result
);
}
auto
dito_int
=
math
::
DeviceIndependenceTensorOperations
<
platform
::
CPUDeviceContext
,
...
...
paddle/fluid/operators/matrix_rank_op.cu
浏览文件 @
def81b4f
...
...
@@ -129,10 +129,10 @@ class MatrixRankGPUKernel : public framework::OpKernel<T> {
compare_result
.
mutable_data
<
int64_t
>
(
detail
::
NewAxisDim
(
dim_out
,
k
),
context
.
GetPlace
());
int
axis
=
-
1
;
ElementwiseComputeEx
<
GreaterThanFunctor
<
T
>
,
platform
::
CUDADeviceContext
,
T
,
int64_t
>
(
context
,
&
eigenvalue_tensor
,
&
tol_tensor
,
axis
,
GreaterThanFunctor
<
T
>
()
,
&
compare_result
);
ElementwiseComputeEx
<
GreaterThanFunctor
<
T
,
int64_t
>
,
platform
::
CUDADeviceContext
,
T
,
int64_t
>
(
context
,
&
eigenvalue_tensor
,
&
tol_tensor
,
axis
,
GreaterThanFunctor
<
T
,
int64_t
>
(),
&
compare_result
);
auto
dito_int
=
math
::
DeviceIndependenceTensorOperations
<
platform
::
CUDADeviceContext
,
int64_t
>
(
context
);
...
...
paddle/fluid/operators/matrix_rank_op.h
浏览文件 @
def81b4f
...
...
@@ -16,6 +16,7 @@
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/controlflow/compare_op.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -46,16 +47,6 @@ static DDim RemoveLastDim(const DDim& dim) {
}
}
// namespace detail
template
<
typename
T
>
struct
GreaterThanFunctor
{
HOSTDEVICE
int
operator
()(
const
T
a
,
const
T
b
)
const
{
return
a
>
b
;
}
};
template
<
typename
T
>
struct
LessThanFunctor
{
HOSTDEVICE
int
operator
()(
const
T
a
,
const
T
b
)
const
{
return
a
<
b
;
}
};
template
<
typename
T
>
struct
GreaterElementFunctor
{
HOSTDEVICE
T
operator
()(
const
T
a
,
const
T
b
)
const
{
...
...
paddle/fluid/operators/viterbi_decode_op.cu
浏览文件 @
def81b4f
...
...
@@ -72,7 +72,8 @@ struct BinaryOperation<platform::CUDADeviceContext, BinaryFunctor, T> {
}
};
template
<
template
<
typename
T
>
typename
CompareFunctor
,
typename
T
>
template
<
template
<
typename
InT
,
typename
OutT
>
typename
CompareFunctor
,
typename
T
>
struct
GetMask
<
platform
::
CUDADeviceContext
,
CompareFunctor
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
lhs
,
const
Tensor
&
rhs
,
Tensor
*
mask
)
{
...
...
@@ -81,7 +82,7 @@ struct GetMask<platform::CUDADeviceContext, CompareFunctor, T> {
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
paddle
::
operators
::
LaunchSameDimsElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
int64_t
,
T
>
(
dev_ctx
,
ins
,
&
outs
,
CompareFunctor
<
int64_t
>
());
CompareFunctor
<
int64_t
,
T
>
());
}
};
...
...
paddle/fluid/operators/viterbi_decode_op.h
浏览文件 @
def81b4f
...
...
@@ -112,12 +112,13 @@ void SameDimsBinaryOP(const Tensor& lhs, const Tensor& rhs, Tensor* out) {
}
}
template
<
typename
DeviceContext
,
template
<
typename
T
>
typename
CompareFunctor
,
template
<
typename
DeviceContext
,
template
<
typename
InT
,
typename
OutT
>
typename
CompareFunctor
,
typename
T
>
struct
GetMask
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
lhs
,
const
Tensor
&
rhs
,
Tensor
*
mask
)
{
SameDimsBinaryOP
<
int64_t
,
CompareFunctor
<
int64_t
>
,
T
>
(
lhs
,
rhs
,
mask
);
SameDimsBinaryOP
<
int64_t
,
CompareFunctor
<
int64_t
,
T
>
,
T
>
(
lhs
,
rhs
,
mask
);
}
};
...
...
python/paddle/fluid/tests/unittests/test_compare_op.py
浏览文件 @
def81b4f
...
...
@@ -140,6 +140,19 @@ def create_paddle_case(op_type, callback):
self
.
assertEqual
((
out
.
numpy
()
==
self
.
real_result
).
all
(),
True
)
paddle
.
enable_static
()
def
test_not_equal
(
self
):
if
self
.
op_type
==
"not_equal"
:
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
np
.
array
([
1.2e-8
,
2
,
2
,
1
]),
dtype
=
"float32"
)
y
=
paddle
.
to_tensor
(
np
.
array
([
1.1e-8
,
2
,
2
,
1
]),
dtype
=
"float32"
)
op
=
eval
(
"paddle.%s"
%
(
self
.
op_type
))
out
=
op
(
x
,
y
)
self
.
real_result
=
np
.
array
([
0
,
0
,
0
,
0
]).
astype
(
np
.
int64
)
self
.
assertEqual
((
out
.
numpy
()
==
self
.
real_result
).
all
(),
True
)
paddle
.
enable_static
()
def
test_assert
(
self
):
def
test_dynamic_api_string
(
self
):
if
self
.
op_type
==
"equal"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录