提交 179b7893 编写于 作者: D dzhwinter

"fix CopyToPeer"

上级 709c157a
...@@ -65,7 +65,7 @@ struct LoD : public std::vector<Vector<size_t>> { ...@@ -65,7 +65,7 @@ struct LoD : public std::vector<Vector<size_t>> {
void CopyToPeer(platform::Place place) { void CopyToPeer(platform::Place place) {
for (auto it = this->begin(); it != this->end(); ++it) { for (auto it = this->begin(); it != this->end(); ++it) {
it->mutable_data(place); it->CopyToPeer(place);
} }
} }
}; };
......
...@@ -82,7 +82,7 @@ inline const T *Vector<T>::data(platform::Place place) const { ...@@ -82,7 +82,7 @@ inline const T *Vector<T>::data(platform::Place place) const {
if (cuda_ptr_ == nullptr) { if (cuda_ptr_ == nullptr) {
return nullptr; return nullptr;
} }
if (platform::is_same_place(place, place_)) { if (boost::get<platform::CUDAPlace>(place) == place_) {
return static_cast<const T *>(cuda_ptr_.get()); return static_cast<const T *>(cuda_ptr_.get());
} else { } else {
PADDLE_THROW( PADDLE_THROW(
...@@ -99,7 +99,7 @@ inline T *Vector<T>::mutable_data(platform::Place place) { ...@@ -99,7 +99,7 @@ inline T *Vector<T>::mutable_data(platform::Place place) {
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
return std::vector<T>::data(); return std::vector<T>::data();
} else if (platform::is_gpu_place(place)) { } else if (platform::is_gpu_place(place)) {
if (!platform::is_same_place(place, place_)) { if (boost::get<platform::CUDAPlace>(place) != place_) {
place_ = boost::get<platform::CUDAPlace>(place); place_ = boost::get<platform::CUDAPlace>(place);
} }
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -159,5 +159,26 @@ void Vector<T>::CopyFromCUDA() { ...@@ -159,5 +159,26 @@ void Vector<T>::CopyFromCUDA() {
#endif #endif
} }
template <typename T>
void Vector<T>::CopyToPeer(platform::Place place) {
#ifdef PADDLE_WITH_CUDA
  // Copies the CPU-side contents of this vector to device memory on `place`,
  // blocking until the transfer completes.
  // NOTE(review): assumes `place` holds a CUDAPlace — boost::get throws
  // (bad_get) otherwise; confirm callers never pass a CPUPlace here.
  bool place_changed = false;
  if (boost::get<platform::CUDAPlace>(place) != place_) {
    place_ = boost::get<platform::CUDAPlace>(place);
    place_changed = true;
  }
  // Reallocate when the buffer is absent, too small, or lives on a different
  // device. The original only checked the size, so a large-enough buffer
  // allocated on the *previous* device would be reused and the copy below
  // would target memory on the wrong GPU.
  if (place_changed || cuda_size_ < this->size() || cuda_ptr_ == nullptr) {
    cuda_ptr_.reset(
        memory::Alloc<platform::CUDAPlace>(place_, this->size() * sizeof(T)),
        memory::PlainDeleter<void, platform::CUDAPlace>(place_));
  }
  cuda_size_ = this->size();
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto *ctx = pool.GetByPlace(place_);
  // Host-to-device copy on the device's stream; Wait() makes the transfer
  // synchronous so the data is visible to callers as soon as we return.
  memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(),
               static_cast<const void *>(this->data()),
               this->size() * sizeof(T), ctx->stream());
  ctx->Wait();
#endif
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -79,7 +79,7 @@ inline void CopyOrShare(const framework::Variable &src, ...@@ -79,7 +79,7 @@ inline void CopyOrShare(const framework::Variable &src,
dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod()); dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod());
} else { } else {
Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>()); Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>());
LoD lod(src.Get<LoDTensor>().lod()); framework::LoD lod(src.Get<LoDTensor>().lod());
lod.CopyToPeer(dst_place); lod.CopyToPeer(dst_place);
dst->GetMutable<LoDTensor>()->set_lod(lod); dst->GetMutable<LoDTensor>()->set_lod(lod);
} }
...@@ -92,7 +92,7 @@ inline void CopyOrShare(const framework::Variable &src, ...@@ -92,7 +92,7 @@ inline void CopyOrShare(const framework::Variable &src,
dst_sr->set_rows(src_sr.rows()); dst_sr->set_rows(src_sr.rows());
} else { } else {
Copy(src_sr.value(), dst_place, dst_sr->mutable_value()); Copy(src_sr.value(), dst_place, dst_sr->mutable_value());
LoD lod(src.Get<LoDTensor>().lod()); framework::Vector<int64_t> lod(src_sr.rows());
lod.CopyToPeer(dst_place); lod.CopyToPeer(dst_place);
dst_sr->set_rows(lod); dst_sr->set_rows(lod);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册