diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index f59314f8288d37f0c645b99811b1355f9a496c00..0997983a03877b47a6772affeb294a89ec0e3f1f 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -199,7 +199,9 @@ class OperatorWithKernel : public OperatorBase { place_ = dev_ctx.GetPlace(); } - bool operator==(const OpKernelKey& o) const { return place_ == o.place_; } + bool operator==(const OpKernelKey& o) const { + return platform::places_are_same_class(place_, o.place_); + } }; struct OpKernelHash {