Commit 4ecf68e0, authored by qijun

fix bug in register gpu OpKernel

Parent 358261f0
@@ -403,15 +403,16 @@ class GradOpRegisterHelper {
   STATIC_ASSERT_GLOBAL_NAMESPACE( \
       __reg_op_kernel_##type##_##DEVICE_TYPE##__, \
       "REGISTER_OP_KERNEL must be in global namespace"); \
-  struct __op_kernel_register__##type##__ { \
-    __op_kernel_register__##type##__() { \
+  struct __op_kernel_register__##type##__##DEVICE_TYPE##__ { \
+    __op_kernel_register__##type##__##DEVICE_TYPE##__() { \
       ::paddle::framework::OperatorWithKernel::OpKernelKey key; \
       key.place_ = PlaceType(); \
       ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
           .reset(new __VA_ARGS__()); \
     } \
   }; \
-  static __op_kernel_register__##type##__ __reg_kernel_##type##__; \
+  static __op_kernel_register__##type##__##DEVICE_TYPE##__ \
+      __reg_kernel_##type##__##DEVICE_TYPE##__; \
   int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; }

 // (type, KernelType)
......
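Why the rename matters: the registrar struct and its static instance previously encoded only the op type, so registering kernels for the same op on two devices (CPU and GPU) in one translation unit would define the same symbols twice. Appending ##DEVICE_TYPE makes the names unique per device. Below is a minimal standalone sketch of that failure mode and fix; it is not Paddle's actual macro, and the names (REGISTER_KERNEL_SKETCH, g_kernels) and the multimap registry are invented for illustration.

#include <iostream>
#include <map>
#include <string>

// Toy registry: op name -> device names registered for it.
static std::multimap<std::string, std::string> g_kernels;

#define REGISTER_KERNEL_SKETCH(op, DEVICE)                           \
  struct __reg_##op##_##DEVICE##__ {                                 \
    __reg_##op##_##DEVICE##__() { g_kernels.emplace(#op, #DEVICE); } \
  };                                                                 \
  static __reg_##op##_##DEVICE##__ __reg_var_##op##_##DEVICE##__;

// Without ##DEVICE in the struct/variable names, these two registrations
// in the same translation unit would redefine the same symbols and fail
// to compile; with it, both kernels register at static-init time.
REGISTER_KERNEL_SKETCH(add_op, CPU)
REGISTER_KERNEL_SKETCH(add_op, GPU)

int main() { std::cout << g_kernels.count("add_op") << "\n"; }  // prints 2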
@@ -199,7 +199,11 @@ class OperatorWithKernel : public OperatorBase {
       place_ = dev_ctx.GetPlace();
     }

-    bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
+    // bool operator==(const OpKernelKey& o) const { return place_ == o.place_;
+    // }
+    bool operator==(const OpKernelKey& o) const {
+      return platform::places_are_same_class(place_, o.place_);
+    }
   };

   struct OpKernelHash {
......
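The old equality compared Place values exactly, while the key stored at registration time is built from a default-constructed PlaceType(); a lookup key built from the runtime device context can hold the same place class with a different device id and would then never match. Switching to platform::places_are_same_class (which, per its name and use here, presumably compares only the kind of place, not the device id) makes the kernel lookup device-agnostic. A compilable sketch of that distinction, using std::variant and toy CPUPlace/GPUPlace types as stand-ins for Paddle's Place variant:

#include <cassert>
#include <variant>

struct CPUPlace {};
struct GPUPlace { int device = 0; };
using Place = std::variant<CPUPlace, GPUPlace>;

// Compare only which alternative is held (CPU vs GPU), ignoring device id,
// mirroring what the new operator== relies on.
bool places_are_same_class(const Place &a, const Place &b) {
  return a.index() == b.index();
}

int main() {
  Place registered = GPUPlace{0};  // key stored when the kernel registers
  Place runtime = GPUPlace{1};     // key built from the device context
  assert(places_are_same_class(registered, runtime));   // lookup succeeds
  assert(!places_are_same_class(Place{CPUPlace{}}, runtime));
  return 0;
}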
@@ -80,9 +80,11 @@ PYBIND11_PLUGIN(core) {
             self.mutable_data<int>(place);
           })
       .def("set", paddle::pybind::PyCPUTensorSetFromArray<float>)
-      .def("set", paddle::pybind::PyCUDATensorSetFromArray<float>)
       .def("set", paddle::pybind::PyCPUTensorSetFromArray<int>)
+#ifndef PADDLE_ONLY_CPU
+      .def("set", paddle::pybind::PyCUDATensorSetFromArray<float>)
+      .def("set", paddle::pybind::PyCUDATensorSetFromArray<int>)
+#endif
       .def("shape",
            [](pd::Tensor& self) { return pd::vectorize(self.dims()); });
......
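Both CUDA-backed "set" overloads now sit inside #ifndef PADDLE_ONLY_CPU, so a CPU-only build of the pybind module still compiles, and pybind11 dispatches to whichever registered overload matches the place argument. A small sketch of the same guarding pattern, using the newer PYBIND11_MODULE macro rather than the PYBIND11_PLUGIN used in this file; Tensor, CPUPlace, GPUPlace, SetOnCPU, and SetOnGPU are toy stand-ins, not Paddle symbols.

#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Tensor { float value = 0.f; };
struct CPUPlace {};
struct GPUPlace {};

void SetOnCPU(Tensor &t, CPUPlace) { t.value = 1.f; }
#ifndef PADDLE_ONLY_CPU
void SetOnGPU(Tensor &t, GPUPlace) { t.value = 2.f; }
#endif

PYBIND11_MODULE(example, m) {
  py::class_<CPUPlace>(m, "CPUPlace").def(py::init<>());
  py::class_<GPUPlace>(m, "GPUPlace").def(py::init<>());
  py::class_<Tensor>(m, "Tensor")
      .def(py::init<>())
      // Overloads are tried in registration order; the place argument's
      // type decides which one runs.
      .def("set", &SetOnCPU)
#ifndef PADDLE_ONLY_CPU
      .def("set", &SetOnGPU)  // compiled only in GPU-enabled builds
#endif
      ;
}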
@@ -42,9 +42,6 @@ template <size_t I, typename... ARGS>
 struct CastToPyBufferImpl<true, I, ARGS...> {
   using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
   py::buffer_info operator()(framework::Tensor &tensor) {
-    PADDLE_ENFORCE(paddle::platform::is_cpu_place(tensor.holder_->place()),
-                   "Only CPU tensor can cast to numpy array");
     if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) {
       auto dim_vec = framework::vectorize(tensor.dims());
       std::vector<size_t> dims_outside;
@@ -99,6 +96,7 @@ void PyCPUTensorSetFromArray(
   std::memcpy(dst, array.data(), sizeof(T) * array.size());
 }

+#ifndef PADDLE_ONLY_CPU
 template <typename T>
 void PyCUDATensorSetFromArray(
     framework::Tensor &self,
@@ -112,10 +110,10 @@ void PyCUDATensorSetFromArray(
   self.Resize(framework::make_ddim(dims));
   auto *dst = self.mutable_data<T>(place);
-  std::memcpy(dst, array.data(), sizeof(T) * array.size());
   paddle::platform::GpuMemcpySync(
       dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice);
 }
+#endif

 } // namespace pybind
 } // namespace paddle
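The removed std::memcpy was the core of the bug: mutable_data<T>(place) with a GPU place hands back device memory, so memcpy-ing the numpy buffer into it dereferences a device pointer on the host. The copy has to go through GpuMemcpySync, which from its arguments is presumably a thin wrapper over cudaMemcpy. A minimal host-side sketch of the correct host-to-device copy (plain CUDA runtime calls, not Paddle's wrappers):

#include <cassert>
#include <vector>
#include <cuda_runtime.h>

int main() {
  std::vector<float> host(16, 1.0f);
  float *dst = nullptr;
  cudaMalloc(&dst, host.size() * sizeof(float));  // dst is a *device* pointer
  // Wrong: std::memcpy(dst, host.data(), ...) would dereference the device
  // pointer on the host, which is what the removed line effectively did.
  cudaMemcpy(dst, host.data(), host.size() * sizeof(float),
             cudaMemcpyHostToDevice);  // correct host-to-device copy
  float back = 0.0f;
  cudaMemcpy(&back, dst, sizeof(float), cudaMemcpyDeviceToHost);
  assert(back == 1.0f);
  cudaFree(dst);
  return 0;
}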