提交 179b7893 编写于 作者: D dzhwinter

"fix CopyToPeer"

上级 709c157a
......@@ -65,7 +65,7 @@ struct LoD : public std::vector<Vector<size_t>> {
// Pushes every level of this LoD to `place`: for each contained vector it
// calls mutable_data(place) and then CopyToPeer(place), in that order.
//
// NOTE(review): this text comes from a diff view without +/- markers; the
// mutable_data() and CopyToPeer() calls below look like the removed and added
// sides of the same line. Confirm against the real header which of the two
// is meant to remain before relying on both running.
void CopyToPeer(platform::Place place) {
// Visit each per-level vector stored in this LoD.
for (auto it = this->begin(); it != this->end(); ++it) {
it->mutable_data(place);
it->CopyToPeer(place);
}
}
};
......
......@@ -82,7 +82,7 @@ inline const T *Vector<T>::data(platform::Place place) const {
if (cuda_ptr_ == nullptr) {
return nullptr;
}
if (platform::is_same_place(place, place_)) {
if (boost::get<platform::CUDAPlace>(place) == place_) {
return static_cast<const T *>(cuda_ptr_.get());
} else {
PADDLE_THROW(
......@@ -99,7 +99,7 @@ inline T *Vector<T>::mutable_data(platform::Place place) {
if (platform::is_cpu_place(place)) {
return std::vector<T>::data();
} else if (platform::is_gpu_place(place)) {
if (!platform::is_same_place(place, place_)) {
if (boost::get<platform::CUDAPlace>(place) != place_) {
place_ = boost::get<platform::CUDAPlace>(place);
}
#ifdef PADDLE_WITH_CUDA
......@@ -159,5 +159,26 @@ void Vector<T>::CopyFromCUDA() {
#endif
}
/// Uploads the host-side contents of this vector to the GPU named by `place`
/// and records that GPU as the vector's device place (`place_`).
///
/// @param place Target place. It is read through
///              boost::get<platform::CUDAPlace>, so it has to hold a
///              CUDAPlace (presumably boost::get throws otherwise -- verify
///              against the definition of platform::Place).
///
/// The device buffer is (re)allocated when there is no buffer yet, when the
/// recorded device size is smaller than the host data, or when the target
/// device differs from the one currently recorded in `place_`. The copy is
/// issued on the stream of the target device's context and ctx->Wait() is
/// called before returning.
///
/// The whole body is compiled out when PADDLE_WITH_CUDA is not defined.
template <typename T>
void Vector<T>::CopyToPeer(platform::Place place) {
#ifdef PADDLE_WITH_CUDA
  platform::CUDAPlace target = boost::get<platform::CUDAPlace>(place);

  // A buffer that was allocated for another device must not be reused: it
  // lives in that device's memory (and, since cuda_ptr_ is reference counted,
  // may still be shared with the vector this one was copied from). Reusing it
  // would write the data to the old device while place_ claims the new one.
  const bool device_changed = (target != place_);
  if (device_changed) {
    place_ = target;
  }

  const bool needs_alloc =
      device_changed || cuda_ptr_ == nullptr || cuda_size_ < this->size();
  if (needs_alloc) {
    // The deleter is bound to the device the memory is allocated on, so the
    // buffer is released against the right place later on.
    cuda_ptr_.reset(
        memory::Alloc<platform::CUDAPlace>(place_, this->size() * sizeof(T)),
        memory::PlainDeleter<void, platform::CUDAPlace>(place_));
  }
  cuda_size_ = this->size();

  // Host -> device copy on the target device's stream, then block until the
  // context reports completion so callers can use the data immediately.
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto *ctx = pool.GetByPlace(place_);
  memory::Copy(place_, cuda_ptr_.get(), platform::CPUPlace(),
               static_cast<const void *>(this->data()),
               this->size() * sizeof(T), ctx->stream());
  ctx->Wait();
#endif
}
} // namespace framework
} // namespace paddle
......@@ -79,7 +79,7 @@ inline void CopyOrShare(const framework::Variable &src,
dst->GetMutable<LoDTensor>()->set_lod(src.Get<LoDTensor>().lod());
} else {
Copy(src.Get<LoDTensor>(), dst_place, dst->GetMutable<LoDTensor>());
LoD lod(src.Get<LoDTensor>().lod());
framework::LoD lod(src.Get<LoDTensor>().lod());
lod.CopyToPeer(dst_place);
dst->GetMutable<LoDTensor>()->set_lod(lod);
}
......@@ -92,7 +92,7 @@ inline void CopyOrShare(const framework::Variable &src,
dst_sr->set_rows(src_sr.rows());
} else {
Copy(src_sr.value(), dst_place, dst_sr->mutable_value());
LoD lod(src.Get<LoDTensor>().lod());
framework::Vector<int64_t> lod(src_sr.rows());
lod.CopyToPeer(dst_place);
dst_sr->set_rows(lod);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册