diff --git a/paddle/operators/compare_op.cc b/paddle/operators/compare_op.cc index 8b425d14df3bc484437dc72f29abf13b887006bd..716b5ee92d0d8737d2069460f53989f691ff7c77 100644 --- a/paddle/operators/compare_op.cc +++ b/paddle/operators/compare_op.cc @@ -14,6 +14,7 @@ #include "paddle/operators/compare_op.h" #include "paddle/framework/op_registry.h" + namespace paddle { namespace operators { template @@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase { } }; +class CompareOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext &ctx) const override { + framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx); + // CompareOp kernel's device type is decided by input tensor place + kt.place_ = ctx.Input("X")->place(); + return kt; + } +}; + } // namespace operators } // namespace paddle -#define REGISTER_LOGICAL_OP(op_type, _equation) \ - struct _##op_type##Comment { \ - static char type[]; \ - static char equation[]; \ - }; \ - char _##op_type##Comment::type[]{#op_type}; \ - char _##op_type##Comment::equation[]{_equation}; \ - REGISTER_OP_WITH_KERNEL( \ - op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \ - ::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \ +#define REGISTER_LOGICAL_OP(op_type, _equation) \ + struct _##op_type##Comment { \ + static char type[]; \ + static char equation[]; \ + }; \ + char _##op_type##Comment::type[]{#op_type}; \ + char _##op_type##Comment::equation[]{_equation}; \ + REGISTER_OPERATOR( \ + op_type, ::paddle::operators::CompareOp, \ + ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \ + ::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \ ::paddle::framework::EmptyGradOpMaker); REGISTER_LOGICAL_OP(less_than, "Out = X < Y"); diff --git a/paddle/platform/transform.h b/paddle/platform/transform.h index f196868c725cbb91b3df710260c5b60f14d53f37..bb9d59ec0a18ce013632f128c9b5d230255f1ac4 100644 --- a/paddle/platform/transform.h +++ b/paddle/platform/transform.h @@ -49,8 +49,6 @@ struct Transform { template void operator()(const DeviceContext& context, InputIter first, InputIter last, OutputIter result, UnaryOperation op) { - auto place = context.GetPlace(); - PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place."); std::transform(first, last, result, op); } @@ -59,8 +57,6 @@ struct Transform { void operator()(const DeviceContext& context, InputIter1 first1, InputIter1 last1, InputIter2 first2, OutputIter result, BinaryOperation op) { - auto place = context.GetPlace(); - PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place."); std::transform(first1, last1, first2, result, op); } };