Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
7f22a6d1
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7f22a6d1
编写于
11月 08, 2017
作者:
Y
Yu Yang
提交者:
GitHub
11月 08, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5465 from reyoung/feature/compare_op_support_cpu
CompareOp's kernel device type is decided by input tensor place
上级
2a76b42e
3187451a
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
26 addition
and
14 deletion
+26
-14
paddle/operators/compare_op.cc
paddle/operators/compare_op.cc
+26
-10
paddle/platform/transform.h
paddle/platform/transform.h
+0
-4
未找到文件。
paddle/operators/compare_op.cc
浏览文件 @
7f22a6d1
...
...
@@ -14,6 +14,7 @@
#include "paddle/operators/compare_op.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
OpComment
>
...
...
@@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase {
}
};
class
CompareOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
framework
::
OpKernelType
GetKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
framework
::
OpKernelType
kt
=
OperatorWithKernel
::
GetKernelType
(
ctx
);
// CompareOp kernel's device type is decided by input tensor place
kt
.
place_
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
)
->
place
();
return
kt
;
}
};
}
// namespace operators
}
// namespace paddle
#define REGISTER_LOGICAL_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OP_WITH_KERNEL( \
op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
#define REGISTER_LOGICAL_OP(op_type, _equation) \
struct _##op_type##Comment { \
static char type[]; \
static char equation[]; \
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::CompareOp, \
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker);
REGISTER_LOGICAL_OP
(
less_than
,
"Out = X < Y"
);
...
...
paddle/platform/transform.h
浏览文件 @
7f22a6d1
...
...
@@ -49,8 +49,6 @@ struct Transform<platform::CPUPlace> {
template
<
typename
InputIter
,
typename
OutputIter
,
typename
UnaryOperation
>
void
operator
()(
const
DeviceContext
&
context
,
InputIter
first
,
InputIter
last
,
OutputIter
result
,
UnaryOperation
op
)
{
auto
place
=
context
.
GetPlace
();
PADDLE_ENFORCE
(
is_cpu_place
(
place
),
"It must use CPU place."
);
std
::
transform
(
first
,
last
,
result
,
op
);
}
...
...
@@ -59,8 +57,6 @@ struct Transform<platform::CPUPlace> {
void
operator
()(
const
DeviceContext
&
context
,
InputIter1
first1
,
InputIter1
last1
,
InputIter2
first2
,
OutputIter
result
,
BinaryOperation
op
)
{
auto
place
=
context
.
GetPlace
();
PADDLE_ENFORCE
(
is_cpu_place
(
place
),
"It must use CPU place."
);
std
::
transform
(
first1
,
last1
,
first2
,
result
,
op
);
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录