提交 f96e2157 编写于 作者: Q QI JUN 提交者: GitHub

Merge pull request #3059 from QiJune/fix_bug_register_OpKernel

fix bug in OpKernel register macro
......@@ -403,15 +403,16 @@ class GradOpRegisterHelper {
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__reg_op_kernel_##type##_##DEVICE_TYPE##__, \
"REGISTER_OP_KERNEL must be in global namespace"); \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
struct __op_kernel_register__##type##__##DEVICE_TYPE##__ { \
__op_kernel_register__##type##__##DEVICE_TYPE##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new __VA_ARGS__()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__; \
static __op_kernel_register__##type##__##DEVICE_TYPE##__ \
__reg_kernel_##type##__##DEVICE_TYPE##__; \
int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; }
// (type, KernelType)
......
......@@ -201,7 +201,9 @@ class OperatorWithKernel : public OperatorBase {
place_ = dev_ctx.GetPlace();
}
bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
bool operator==(const OpKernelKey& o) const {
return platform::places_are_same_class(place_, o.place_);
}
};
struct OpKernelHash {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册