From 4ecf68e0ea08b71fc061b1104ffeb225592b280d Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 25 Jul 2017 15:58:09 +0000 Subject: [PATCH] fix bug in register gpu OpKernel --- paddle/framework/op_registry.h | 7 ++++--- paddle/framework/operator.h | 6 +++++- paddle/pybind/pybind.cc | 4 +++- paddle/pybind/tensor_bind.h | 6 ++---- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index f16deae028d..384f0f631dd 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -403,15 +403,16 @@ class GradOpRegisterHelper { STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op_kernel_##type##_##DEVICE_TYPE##__, \ "REGISTER_OP_KERNEL must be in global namespace"); \ - struct __op_kernel_register__##type##__ { \ - __op_kernel_register__##type##__() { \ + struct __op_kernel_register__##type##__##DEVICE_TYPE##__ { \ + __op_kernel_register__##type##__##DEVICE_TYPE##__() { \ ::paddle::framework::OperatorWithKernel::OpKernelKey key; \ key.place_ = PlaceType(); \ ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \ .reset(new __VA_ARGS__()); \ } \ }; \ - static __op_kernel_register__##type##__ __reg_kernel_##type##__; \ + static __op_kernel_register__##type##__##DEVICE_TYPE##__ \ + __reg_kernel_##type##__##DEVICE_TYPE##__; \ int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; } // (type, KernelType) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index f59314f8288..97e9ec1bcf8 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -199,7 +199,11 @@ class OperatorWithKernel : public OperatorBase { place_ = dev_ctx.GetPlace(); } - bool operator==(const OpKernelKey& o) const { return place_ == o.place_; } + // bool operator==(const OpKernelKey& o) const { return place_ == o.place_; + // } + bool operator==(const OpKernelKey& o) const { + return platform::places_are_same_class(place_, o.place_); + } }; struct OpKernelHash { diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 27a80f7ffa3..1229451523f 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -80,9 +80,11 @@ PYBIND11_PLUGIN(core) { self.mutable_data(place); }) .def("set", paddle::pybind::PyCPUTensorSetFromArray) - .def("set", paddle::pybind::PyCUDATensorSetFromArray) .def("set", paddle::pybind::PyCPUTensorSetFromArray) +#ifndef PADDLE_ONLY_CPU + .def("set", paddle::pybind::PyCUDATensorSetFromArray) .def("set", paddle::pybind::PyCUDATensorSetFromArray) +#endif .def("shape", [](pd::Tensor& self) { return pd::vectorize(self.dims()); }); diff --git a/paddle/pybind/tensor_bind.h b/paddle/pybind/tensor_bind.h index 86eff97d722..def37219cce 100644 --- a/paddle/pybind/tensor_bind.h +++ b/paddle/pybind/tensor_bind.h @@ -42,9 +42,6 @@ template struct CastToPyBufferImpl { using CUR_TYPE = typename std::tuple_element>::type; py::buffer_info operator()(framework::Tensor &tensor) { - PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()), - "Only CPU tensor can cast to numpy array"); - if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) { auto dim_vec = framework::vectorize(tensor.dims()); std::vector dims_outside; @@ -99,6 +96,7 @@ void PyCPUTensorSetFromArray( std::memcpy(dst, array.data(), sizeof(T) * array.size()); } +#ifndef PADDLE_ONLY_CPU template void PyCUDATensorSetFromArray( framework::Tensor &self, @@ -112,10 +110,10 @@ void PyCUDATensorSetFromArray( self.Resize(framework::make_ddim(dims)); auto *dst = self.mutable_data(place); - std::memcpy(dst, array.data(), sizeof(T) * array.size()); paddle::platform::GpuMemcpySync( dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice); } +#endif } // namespace pybind } // namespace paddle -- GitLab