From 3187451ae7dc8f8e1155e952dc725d321967a85a Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Tue, 7 Nov 2017 20:23:09 -0800 Subject: [PATCH] CompareOp's kernel device type is decided by input tensor place CompareOp can run on CPU even other operators are running on GPU, since opeatations like comparing control flags should be performed only on CPU --- paddle/operators/compare_op.cc | 36 ++++++++++++++++++++++++---------- paddle/platform/transform.h | 4 ---- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/paddle/operators/compare_op.cc b/paddle/operators/compare_op.cc index 8b425d14df..716b5ee92d 100644 --- a/paddle/operators/compare_op.cc +++ b/paddle/operators/compare_op.cc @@ -14,6 +14,7 @@ #include "paddle/operators/compare_op.h" #include "paddle/framework/op_registry.h" + namespace paddle { namespace operators { template @@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase { } }; +class CompareOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext &ctx) const override { + framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx); + // CompareOp kernel's device type is decided by input tensor place + kt.place_ = ctx.Input("X")->place(); + return kt; + } +}; + } // namespace operators } // namespace paddle -#define REGISTER_LOGICAL_OP(op_type, _equation) \ - struct _##op_type##Comment { \ - static char type[]; \ - static char equation[]; \ - }; \ - char _##op_type##Comment::type[]{#op_type}; \ - char _##op_type##Comment::equation[]{_equation}; \ - REGISTER_OP_WITH_KERNEL( \ - op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \ - ::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \ +#define REGISTER_LOGICAL_OP(op_type, _equation) \ + struct _##op_type##Comment { \ + static char type[]; \ + static char equation[]; \ + }; \ + char _##op_type##Comment::type[]{#op_type}; \ + char _##op_type##Comment::equation[]{_equation}; \ + REGISTER_OPERATOR( \ + op_type, ::paddle::operators::CompareOp, \ + ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \ + ::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \ ::paddle::framework::EmptyGradOpMaker); REGISTER_LOGICAL_OP(less_than, "Out = X < Y"); diff --git a/paddle/platform/transform.h b/paddle/platform/transform.h index f196868c72..bb9d59ec0a 100644 --- a/paddle/platform/transform.h +++ b/paddle/platform/transform.h @@ -49,8 +49,6 @@ struct Transform { template void operator()(const DeviceContext& context, InputIter first, InputIter last, OutputIter result, UnaryOperation op) { - auto place = context.GetPlace(); - PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place."); std::transform(first, last, result, op); } @@ -59,8 +57,6 @@ struct Transform { void operator()(const DeviceContext& context, InputIter1 first1, InputIter1 last1, InputIter2 first2, OutputIter result, BinaryOperation op) { - auto place = context.GetPlace(); - PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place."); std::transform(first1, last1, first2, result, op); } }; -- GitLab