未验证 提交 7f22a6d1 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #5465 from reyoung/feature/compare_op_support_cpu

CompareOp's kernel device type is decided by input tensor place
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/operators/compare_op.h" #include "paddle/operators/compare_op.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename OpComment> template <typename OpComment>
...@@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase { ...@@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase {
} }
}; };
class CompareOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#define REGISTER_LOGICAL_OP(op_type, _equation) \ #define REGISTER_LOGICAL_OP(op_type, _equation) \
struct _##op_type##Comment { \ struct _##op_type##Comment { \
static char type[]; \ static char type[]; \
static char equation[]; \ static char equation[]; \
}; \ }; \
char _##op_type##Comment::type[]{#op_type}; \ char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \ char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OP_WITH_KERNEL( \ REGISTER_OPERATOR( \
op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \ op_type, ::paddle::operators::CompareOp, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \ ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker); ::paddle::framework::EmptyGradOpMaker);
REGISTER_LOGICAL_OP(less_than, "Out = X < Y"); REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
......
...@@ -49,8 +49,6 @@ struct Transform<platform::CPUPlace> { ...@@ -49,8 +49,6 @@ struct Transform<platform::CPUPlace> {
template <typename InputIter, typename OutputIter, typename UnaryOperation> template <typename InputIter, typename OutputIter, typename UnaryOperation>
void operator()(const DeviceContext& context, InputIter first, InputIter last, void operator()(const DeviceContext& context, InputIter first, InputIter last,
OutputIter result, UnaryOperation op) { OutputIter result, UnaryOperation op) {
auto place = context.GetPlace();
PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
std::transform(first, last, result, op); std::transform(first, last, result, op);
} }
...@@ -59,8 +57,6 @@ struct Transform<platform::CPUPlace> { ...@@ -59,8 +57,6 @@ struct Transform<platform::CPUPlace> {
void operator()(const DeviceContext& context, InputIter1 first1, void operator()(const DeviceContext& context, InputIter1 first1,
InputIter1 last1, InputIter2 first2, OutputIter result, InputIter1 last1, InputIter2 first2, OutputIter result,
BinaryOperation op) { BinaryOperation op) {
auto place = context.GetPlace();
PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
std::transform(first1, last1, first2, result, op); std::transform(first1, last1, first2, result, op);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册