未验证 提交 3187451a 编写于 作者: Y Yang Yu

CompareOp's kernel device type is decided by input tensor place

CompareOp can run on CPU even when other operators are running on GPU, since
operations like comparing control flags should be performed only on CPU
上级 c365c61a
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/operators/compare_op.h" #include "paddle/operators/compare_op.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename OpComment> template <typename OpComment>
...@@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase { ...@@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase {
} }
}; };
// Operator class for comparison ops. It reuses the OperatorWithKernel
// constructors and customizes only how the kernel type is chosen.
class CompareOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Returns the default kernel type with its place replaced by the place of
  // the input tensor "X", so the kernel runs on whichever device holds "X".
  framework::OpKernelType GetKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::OpKernelType kernel_type =
        OperatorWithKernel::GetKernelType(ctx);
    const auto *x_tensor = ctx.Input<framework::LoDTensor>("X");
    kernel_type.place_ = x_tensor->place();
    return kernel_type;
  }
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#define REGISTER_LOGICAL_OP(op_type, _equation) \ #define REGISTER_LOGICAL_OP(op_type, _equation) \
struct _##op_type##Comment { \ struct _##op_type##Comment { \
static char type[]; \ static char type[]; \
static char equation[]; \ static char equation[]; \
}; \ }; \
char _##op_type##Comment::type[]{#op_type}; \ char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \ char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OP_WITH_KERNEL( \ REGISTER_OPERATOR( \
op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \ op_type, ::paddle::operators::CompareOp, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \ ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker); ::paddle::framework::EmptyGradOpMaker);
REGISTER_LOGICAL_OP(less_than, "Out = X < Y"); REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
......
...@@ -49,8 +49,6 @@ struct Transform<platform::CPUPlace> { ...@@ -49,8 +49,6 @@ struct Transform<platform::CPUPlace> {
template <typename InputIter, typename OutputIter, typename UnaryOperation> template <typename InputIter, typename OutputIter, typename UnaryOperation>
void operator()(const DeviceContext& context, InputIter first, InputIter last, void operator()(const DeviceContext& context, InputIter first, InputIter last,
OutputIter result, UnaryOperation op) { OutputIter result, UnaryOperation op) {
auto place = context.GetPlace();
PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
std::transform(first, last, result, op); std::transform(first, last, result, op);
} }
...@@ -59,8 +57,6 @@ struct Transform<platform::CPUPlace> { ...@@ -59,8 +57,6 @@ struct Transform<platform::CPUPlace> {
void operator()(const DeviceContext& context, InputIter1 first1, void operator()(const DeviceContext& context, InputIter1 first1,
InputIter1 last1, InputIter2 first2, OutputIter result, InputIter1 last1, InputIter2 first2, OutputIter result,
BinaryOperation op) { BinaryOperation op) {
auto place = context.GetPlace();
PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
std::transform(first1, last1, first2, result, op); std::transform(first1, last1, first2, result, op);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册