From 2ab12ca2480979ba4d651adff7fef6daf3a86ce7 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 17 Apr 2018 10:26:04 +0800 Subject: [PATCH] Add comments and clean code --- paddle/fluid/pybind/pybind.cc | 6 ++++-- paddle/fluid/pybind/tensor_py.h | 9 +++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 254c4a5942d..5121987f922 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -505,13 +505,15 @@ All parameter, weight, gradient are variables in Paddle. scope, local_scopes, allow_op_delay); }) .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs) + // NOTE: even though we return a vec* to Python with the reference policy, + // we still cannot get a local_scope from this vector, since the elements + // of the vec would be freed by the Python GC. We can only return Scope* + // one by one and mark each of them as a reference. .def("local_scopes", [](ParallelExecutor &self) -> std::vector * { return &self.GetLocalScopes(); }, py::return_value_policy::reference) - .def("local_scopes_len", - [](ParallelExecutor &self) { return self.GetLocalScopes().size(); }) .def("local_scope", [](ParallelExecutor &self, size_t i) { return self.GetLocalScopes()[i]; }, py::return_value_policy::reference) diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index c9cad15a74b..159d1d5f4e7 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -190,6 +190,10 @@ void PyCUDATensorSetFromArray( static_cast(pool.Get(place)); paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice, dev_ctx->stream()); + // NOTE: For safety, wait here until the copy completes. + // This is because the CPU-side array.data() could be destroyed after this method returns. + // If we made this method asynchronous, it could copy data from a memory buffer + // that has already been freed. 
dev_ctx->Wait(); } @@ -217,6 +221,11 @@ void PyCUDATensorSetFromArray( paddle::platform::GpuMemcpyAsync(dst, array.data(), sizeof(uint16_t) * array.size(), cudaMemcpyHostToDevice, dev_ctx->stream()); + // NOTE: For safety, wait here until the copy completes. + // This is because the CPU-side array.data() could be destroyed after this method returns. + // If we made this method asynchronous, it could copy data from a memory buffer + // that has already been freed. + dev_ctx->Wait(); } template -- GitLab