Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
54a4696f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
54a4696f
编写于
1月 22, 2018
作者:
Y
Yu Yang
提交者:
GitHub
1月 22, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7660 from reyoung/feature/compare_op_use_elemwise
Make compare_op reuse elemwise_op_funcs
上级
430fdc52
2024489b
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
15 addition
and
38 deletion
+15
-38
paddle/operators/compare_op.cc
paddle/operators/compare_op.cc
+5
-6
paddle/operators/compare_op.cu
paddle/operators/compare_op.cu
+0
-4
paddle/operators/compare_op.h
paddle/operators/compare_op.h
+2
-20
paddle/operators/elementwise_op_function.h
paddle/operators/elementwise_op_function.h
+8
-6
python/paddle/v2/fluid/tests/test_compare_op.py
python/paddle/v2/fluid/tests/test_compare_op.py
+0
-2
未找到文件。
paddle/operators/compare_op.cc
浏览文件 @
54a4696f
...
@@ -39,6 +39,11 @@ N-dim tensor. X and Y could be any type. The each element of the Out tensor is
...
@@ -39,6 +39,11 @@ N-dim tensor. X and Y could be any type. The each element of the Out tensor is
calculated by %s
calculated by %s
)DOC"
,
)DOC"
,
comment
.
type
,
comment
.
equation
));
comment
.
type
,
comment
.
equation
));
AddAttr
<
int
>
(
"axis"
,
"(int, default -1). The start dimension index "
"for broadcasting Y onto X."
)
.
SetDefault
(
-
1
)
.
EqualGreaterThan
(
-
1
);
}
}
};
};
...
@@ -95,11 +100,5 @@ REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
...
@@ -95,11 +100,5 @@ REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
REGISTER_LOGICAL_KERNEL
(
less_than
,
CPU
,
paddle
::
operators
::
LessThanFunctor
);
REGISTER_LOGICAL_KERNEL
(
less_than
,
CPU
,
paddle
::
operators
::
LessThanFunctor
);
REGISTER_LOGICAL_OP
(
less_equal
,
"Out = X <= Y"
);
REGISTER_LOGICAL_OP
(
less_equal
,
"Out = X <= Y"
);
REGISTER_LOGICAL_KERNEL
(
less_equal
,
CPU
,
paddle
::
operators
::
LessEqualFunctor
);
REGISTER_LOGICAL_KERNEL
(
less_equal
,
CPU
,
paddle
::
operators
::
LessEqualFunctor
);
REGISTER_LOGICAL_OP
(
greater_than
,
"Out = X > Y"
);
REGISTER_LOGICAL_KERNEL
(
greater_than
,
CPU
,
paddle
::
operators
::
GreaterThanFunctor
);
REGISTER_LOGICAL_OP
(
greater_equal
,
"Out = X >= Y"
);
REGISTER_LOGICAL_KERNEL
(
greater_equal
,
CPU
,
paddle
::
operators
::
GreaterEqualFunctor
);
REGISTER_LOGICAL_OP
(
equal
,
"Out = X == Y"
);
REGISTER_LOGICAL_OP
(
equal
,
"Out = X == Y"
);
REGISTER_LOGICAL_KERNEL
(
equal
,
CPU
,
paddle
::
operators
::
EqualFunctor
);
REGISTER_LOGICAL_KERNEL
(
equal
,
CPU
,
paddle
::
operators
::
EqualFunctor
);
paddle/operators/compare_op.cu
浏览文件 @
54a4696f
...
@@ -16,8 +16,4 @@ limitations under the License. */
...
@@ -16,8 +16,4 @@ limitations under the License. */
REGISTER_LOGICAL_KERNEL
(
less_than
,
CUDA
,
paddle
::
operators
::
LessThanFunctor
);
REGISTER_LOGICAL_KERNEL
(
less_than
,
CUDA
,
paddle
::
operators
::
LessThanFunctor
);
REGISTER_LOGICAL_KERNEL
(
less_equal
,
CUDA
,
paddle
::
operators
::
LessEqualFunctor
);
REGISTER_LOGICAL_KERNEL
(
less_equal
,
CUDA
,
paddle
::
operators
::
LessEqualFunctor
);
REGISTER_LOGICAL_KERNEL
(
greater_than
,
CUDA
,
paddle
::
operators
::
GreaterThanFunctor
);
REGISTER_LOGICAL_KERNEL
(
greater_equal
,
CUDA
,
paddle
::
operators
::
GreaterEqualFunctor
);
REGISTER_LOGICAL_KERNEL
(
equal
,
CUDA
,
paddle
::
operators
::
EqualFunctor
);
REGISTER_LOGICAL_KERNEL
(
equal
,
CUDA
,
paddle
::
operators
::
EqualFunctor
);
paddle/operators/compare_op.h
浏览文件 @
54a4696f
...
@@ -16,6 +16,7 @@ limitations under the License. */
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <math.h>
#include <math.h>
#include <type_traits>
#include <type_traits>
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/elementwise_op_function.h"
#include "paddle/platform/transform.h"
#include "paddle/platform/transform.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -33,18 +34,6 @@ struct LessEqualFunctor {
...
@@ -33,18 +34,6 @@ struct LessEqualFunctor {
HOSTDEVICE
bool
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
<=
b
;
}
HOSTDEVICE
bool
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
<=
b
;
}
};
};
template
<
typename
T
>
struct
GreaterThanFunctor
{
using
ELEM_TYPE
=
T
;
HOSTDEVICE
bool
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
>
b
;
}
};
template
<
typename
T
>
struct
GreaterEqualFunctor
{
using
ELEM_TYPE
=
T
;
HOSTDEVICE
bool
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
>=
b
;
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
EqualFunctor
{
struct
EqualFunctor
{
using
ELEM_TYPE
=
T
;
using
ELEM_TYPE
=
T
;
...
@@ -65,14 +54,7 @@ class CompareOpKernel
...
@@ -65,14 +54,7 @@ class CompareOpKernel
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
using
T
=
typename
Functor
::
ELEM_TYPE
;
using
T
=
typename
Functor
::
ELEM_TYPE
;
auto
*
x
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
ElementwiseComputeEx
<
Functor
,
DeviceContext
,
T
,
bool
>
(
context
);
auto
*
y
=
context
.
Input
<
framework
::
Tensor
>
(
"Y"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
Functor
binary_func
;
platform
::
Transform
<
DeviceContext
>
trans
;
trans
(
context
.
template
device_context
<
DeviceContext
>(),
x
->
data
<
T
>
(),
x
->
data
<
T
>
()
+
x
->
numel
(),
y
->
data
<
T
>
(),
out
->
mutable_data
<
bool
>
(
context
.
GetPlace
()),
binary_func
);
}
}
};
};
...
...
paddle/operators/elementwise_op_function.h
浏览文件 @
54a4696f
...
@@ -176,14 +176,15 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext>
...
@@ -176,14 +176,15 @@ class MidWiseTransformIterator<T, platform::CUDADeviceContext>
};
};
#endif
#endif
template
<
typename
Functor
,
typename
T
,
typename
DeviceContext
>
template
<
typename
Functor
,
typename
T
,
typename
DeviceContext
,
typename
OutType
=
T
>
class
TransformFunctor
{
class
TransformFunctor
{
public:
public:
TransformFunctor
(
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
TransformFunctor
(
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
,
const
DeviceContext
&
ctx
,
Functor
func
)
framework
::
Tensor
*
z
,
const
DeviceContext
&
ctx
,
Functor
func
)
:
x_
(
x
->
data
<
T
>
()),
:
x_
(
x
->
data
<
T
>
()),
y_
(
y
->
data
<
T
>
()),
y_
(
y
->
data
<
T
>
()),
z_
(
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
())),
z_
(
z
->
mutable_data
<
OutType
>
(
ctx
.
GetPlace
())),
nx_
(
x
->
numel
()),
nx_
(
x
->
numel
()),
ctx_
(
ctx
),
ctx_
(
ctx
),
func_
(
func
)
{}
func_
(
func
)
{}
...
@@ -208,7 +209,7 @@ class TransformFunctor {
...
@@ -208,7 +209,7 @@ class TransformFunctor {
private:
private:
const
T
*
x_
;
const
T
*
x_
;
const
T
*
y_
;
const
T
*
y_
;
T
*
z_
;
OutType
*
z_
;
int64_t
nx_
;
int64_t
nx_
;
const
DeviceContext
&
ctx_
;
const
DeviceContext
&
ctx_
;
Functor
func_
;
Functor
func_
;
...
@@ -364,15 +365,16 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
...
@@ -364,15 +365,16 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
}
}
}
}
template
<
typename
Functor
,
typename
DeviceContext
,
typename
T
>
template
<
typename
Functor
,
typename
DeviceContext
,
typename
T
,
typename
OutType
=
T
>
void
ElementwiseComputeEx
(
const
framework
::
ExecutionContext
&
ctx
)
{
void
ElementwiseComputeEx
(
const
framework
::
ExecutionContext
&
ctx
)
{
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
z
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
auto
*
z
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
z
->
mutable_data
<
OutType
>
(
ctx
.
GetPlace
());
TransformFunctor
<
Functor
,
T
,
DeviceContext
>
functor
(
TransformFunctor
<
Functor
,
T
,
DeviceContext
,
OutType
>
functor
(
x
,
y
,
z
,
ctx
.
template
device_context
<
DeviceContext
>(),
Functor
());
x
,
y
,
z
,
ctx
.
template
device_context
<
DeviceContext
>(),
Functor
());
auto
x_dims
=
x
->
dims
();
auto
x_dims
=
x
->
dims
();
...
...
python/paddle/v2/fluid/tests/test_compare_op.py
浏览文件 @
54a4696f
...
@@ -38,8 +38,6 @@ def create_test_class(op_type, typename, callback):
...
@@ -38,8 +38,6 @@ def create_test_class(op_type, typename, callback):
for
_type_name
in
{
'float32'
,
'float64'
,
'int32'
,
'int64'
}:
for
_type_name
in
{
'float32'
,
'float64'
,
'int32'
,
'int64'
}:
create_test_class
(
'less_than'
,
_type_name
,
lambda
_a
,
_b
:
_a
<
_b
)
create_test_class
(
'less_than'
,
_type_name
,
lambda
_a
,
_b
:
_a
<
_b
)
create_test_class
(
'less_equal'
,
_type_name
,
lambda
_a
,
_b
:
_a
<=
_b
)
create_test_class
(
'less_equal'
,
_type_name
,
lambda
_a
,
_b
:
_a
<=
_b
)
create_test_class
(
'greater_than'
,
_type_name
,
lambda
_a
,
_b
:
_a
>
_b
)
create_test_class
(
'greater_equal'
,
_type_name
,
lambda
_a
,
_b
:
_a
>=
_b
)
create_test_class
(
'equal'
,
_type_name
,
lambda
_a
,
_b
:
_a
==
_b
)
create_test_class
(
'equal'
,
_type_name
,
lambda
_a
,
_b
:
_a
==
_b
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录