未验证 提交 3187451a 编写于 作者: Y Yang Yu

CompareOp's kernel device type is decided by input tensor place

CompareOp can run on CPU even other operators are running on GPU, since
opeatations like comparing control flags should be performed only on CPU
上级 c365c61a
......@@ -14,6 +14,7 @@
#include "paddle/operators/compare_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename OpComment>
......@@ -61,6 +62,20 @@ class CompareOpInferShape : public framework::InferShapeBase {
}
};
class CompareOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
}
};
} // namespace operators
} // namespace paddle
......@@ -71,8 +86,9 @@ class CompareOpInferShape : public framework::InferShapeBase {
}; \
char _##op_type##Comment::type[]{#op_type}; \
char _##op_type##Comment::equation[]{_equation}; \
REGISTER_OP_WITH_KERNEL( \
op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
REGISTER_OPERATOR( \
op_type, ::paddle::operators::CompareOp, \
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
::paddle::framework::EmptyGradOpMaker);
......
......@@ -49,8 +49,6 @@ struct Transform<platform::CPUPlace> {
template <typename InputIter, typename OutputIter, typename UnaryOperation>
void operator()(const DeviceContext& context, InputIter first, InputIter last,
OutputIter result, UnaryOperation op) {
auto place = context.GetPlace();
PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
std::transform(first, last, result, op);
}
......@@ -59,8 +57,6 @@ struct Transform<platform::CPUPlace> {
void operator()(const DeviceContext& context, InputIter1 first1,
InputIter1 last1, InputIter2 first2, OutputIter result,
BinaryOperation op) {
auto place = context.GetPlace();
PADDLE_ENFORCE(is_cpu_place(place), "It must use CPU place.");
std::transform(first1, last1, first2, result, op);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册