diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index f16deae028d76dc40d6bc589648b461c430c3c98..384f0f631dd9b9a4dd7c0c628340afe668bc248f 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -403,15 +403,16 @@ class GradOpRegisterHelper {
   STATIC_ASSERT_GLOBAL_NAMESPACE(                                          \
       __reg_op_kernel_##type##_##DEVICE_TYPE##__,                          \
       "REGISTER_OP_KERNEL must be in global namespace");                   \
-  struct __op_kernel_register__##type##__ {                                \
-    __op_kernel_register__##type##__() {                                   \
+  struct __op_kernel_register__##type##__##DEVICE_TYPE##__ {               \
+    __op_kernel_register__##type##__##DEVICE_TYPE##__() {                  \
       ::paddle::framework::OperatorWithKernel::OpKernelKey key;            \
       key.place_ = PlaceType();                                            \
       ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key]  \
           .reset(new __VA_ARGS__());                                       \
     }                                                                      \
   };                                                                       \
-  static __op_kernel_register__##type##__ __reg_kernel_##type##__;         \
+  static __op_kernel_register__##type##__##DEVICE_TYPE##__                 \
+      __reg_kernel_##type##__##DEVICE_TYPE##__;                            \
   int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; }

 // (type, KernelType)
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index f59314f8288d37f0c645b99811b1355f9a496c00..97e9ec1bcf86b2104b5ca5ec242b466a9aa960c8 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -199,7 +199,9 @@ class OperatorWithKernel : public OperatorBase {
       place_ = dev_ctx.GetPlace();
     }

-    bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
+    bool operator==(const OpKernelKey& o) const {
+      return platform::places_are_same_class(place_, o.place_);
+    }
   };

   struct OpKernelHash {
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 27a80f7ffa3f52f99ad6d3f5e84b1c40327299e5..1229451523f3bb42cbfebc8b916c4246958125cb 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -80,9 +80,11 @@ PYBIND11_PLUGIN(core) {
              self.mutable_data<int>(place);
            })
       .def("set", paddle::pybind::PyCPUTensorSetFromArray<float>)
-      .def("set", paddle::pybind::PyCUDATensorSetFromArray<float>)
       .def("set", paddle::pybind::PyCPUTensorSetFromArray<int>)
+#ifndef PADDLE_ONLY_CPU
+      .def("set", paddle::pybind::PyCUDATensorSetFromArray<float>)
       .def("set", paddle::pybind::PyCUDATensorSetFromArray<int>)
+#endif
       .def("shape",
            [](pd::Tensor& self) { return pd::vectorize(self.dims()); });

diff --git a/paddle/pybind/tensor_bind.h b/paddle/pybind/tensor_bind.h
index 86eff97d7224517d5d9a611fd530255d7fe21fcc..def37219ccefd5435f1212c4e4daac5a351d76f4 100644
--- a/paddle/pybind/tensor_bind.h
+++ b/paddle/pybind/tensor_bind.h
@@ -42,9 +42,6 @@ template <size_t I, typename... ARGS>
 struct CastToPyBufferImpl<true, I, ARGS...> {
   using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
   py::buffer_info operator()(framework::Tensor &tensor) {
-    PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()),
-                   "Only CPU tensor can cast to numpy array");
-
     if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) {
       auto dim_vec = framework::vectorize(tensor.dims());
       std::vector<size_t> dims_outside;
@@ -99,6 +96,7 @@ void PyCPUTensorSetFromArray(
   std::memcpy(dst, array.data(), sizeof(T) * array.size());
 }

+#ifndef PADDLE_ONLY_CPU
 template <typename T>
 void PyCUDATensorSetFromArray(
     framework::Tensor &self,
@@ -112,10 +110,10 @@ void PyCUDATensorSetFromArray(

   self.Resize(framework::make_ddim(dims));
   auto *dst = self.mutable_data<T>(place);
-  std::memcpy(dst, array.data(), sizeof(T) * array.size());
   paddle::platform::GpuMemcpySync(
       dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice);
 }
+#endif

 }  // namespace pybind
 }  // namespace paddle
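
The op_registry.h change fixes a symbol collision: the old macro named the registrar struct after the op type only, so expanding REGISTER_OP_KERNEL for both a CPU and a GPU kernel of the same op in one translation unit defined the same struct twice. Pasting ##DEVICE_TYPE## into the struct and variable names keeps each expansion unique. Below is a minimal standalone sketch of this static-registrar pattern (toy registry and names, not Paddle code):

    #include <iostream>
    #include <map>
    #include <string>

    // Toy stand-in for OperatorWithKernel::AllOpKernels().
    std::map<std::string, int>& Registry() {
      static std::map<std::string, int> instance;
      return instance;
    }

    // Pasting both `type` and `DEVICE` into the struct and variable names
    // is what keeps the two expansions below from colliding.
    #define REGISTER_KERNEL(type, DEVICE)                                    \
      struct __reg__##type##__##DEVICE##__ {                                 \
        __reg__##type##__##DEVICE##__() { Registry()[#type "." #DEVICE] = 1; } \
      };                                                                     \
      static __reg__##type##__##DEVICE##__ __reg_var__##type##__##DEVICE##__;

    // With the old naming (no DEVICE in the struct name) these two lines
    // would define the same struct twice and fail to compile.
    REGISTER_KERNEL(add_two, CPU)
    REGISTER_KERNEL(add_two, GPU)

    int main() {
      for (const auto& kv : Registry()) std::cout << kv.first << "\n";
      return 0;  // prints add_two.CPU and add_two.GPU
    }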
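
The operator.h change is what makes kernel lookup tolerant of concrete device ids: registration stores a key built from a default-constructed PlaceType() (see the macro above), while execution builds the key from dev_ctx.GetPlace(), which may carry a device number. Comparing place classes instead of exact values lets the two keys match. A sketch of the assumed semantics with simplified place types (Paddle's Place is a boost::variant over place structs; std::variant stands in here):

    #include <iostream>
    #include <variant>

    struct CPUPlace {};
    struct GPUPlace { int device = 0; };
    using Place = std::variant<CPUPlace, GPUPlace>;

    // Mirrors the assumed behavior of platform::places_are_same_class:
    // same alternative means same class; the device id is ignored.
    bool places_are_same_class(const Place& a, const Place& b) {
      return a.index() == b.index();
    }

    int main() {
      Place registered = GPUPlace{};  // key stored at registration time
      Place runtime = GPUPlace{3};    // key built from the device context
      std::cout << places_are_same_class(registered, runtime) << "\n";     // 1
      std::cout << places_are_same_class(registered, CPUPlace{}) << "\n";  // 0
    }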
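
In tensor_bind.h, the removed std::memcpy was writing host data through a device pointer: mutable_data<T>(place) with a GPU place returns GPU memory, which the CPU cannot dereference directly, so the copy must go through the CUDA runtime (GpuMemcpySync, which wraps cudaMemcpy). The #ifndef PADDLE_ONLY_CPU guards keep these CUDA-only symbols out of CPU-only builds. A bare-bones illustration of the same host-to-device copy with the raw CUDA API (assumes a CUDA toolchain; error checking omitted):

    #include <cstdio>
    #include <cuda_runtime.h>

    int main() {
      const float host[4] = {0.f, 1.f, 2.f, 3.f};
      float* dst = nullptr;
      cudaMalloc(&dst, sizeof(host));  // device-side buffer
      // The fix in the diff boils down to this: an explicit host-to-device
      // copy instead of std::memcpy, which would dereference a device
      // pointer on the CPU (undefined behavior, typically a crash).
      cudaMemcpy(dst, host, sizeof(host), cudaMemcpyHostToDevice);
      cudaFree(dst);
      std::printf("copied %zu bytes to the device\n", sizeof(host));
      return 0;
    }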