提交 303fb789 编写于 作者: Q qijun

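Refine Tensor::CopyFrom: merge the CPUPlace/GPUPlace overloads into a single overload that takes a generic platform::Place destination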

上级 5364b394
...@@ -83,7 +83,7 @@ inline void Tensor::ShareDataWith(const Tensor& src) { ...@@ -83,7 +83,7 @@ inline void Tensor::ShareDataWith(const Tensor& src) {
template <typename T> template <typename T>
inline void Tensor::CopyFrom(const Tensor& src, inline void Tensor::CopyFrom(const Tensor& src,
const platform::CPUPlace& dst_place) { const platform::Place& dst_place) {
src.check_memory_size<T>(); src.check_memory_size<T>();
Resize(src.dims()); Resize(src.dims());
...@@ -94,41 +94,27 @@ inline void Tensor::CopyFrom(const Tensor& src, ...@@ -94,41 +94,27 @@ inline void Tensor::CopyFrom(const Tensor& src,
auto size = product(src.dims_) * sizeof(T); auto size = product(src.dims_) * sizeof(T);
if (platform::is_cpu_place(src_place)) { if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size); boost::get<platform::CPUPlace>(src_place), src_ptr, size);
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
else if (platform::is_gpu_place(src_place)) { else if (platform::is_gpu_place(src_place) &&
platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0); boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
} } else if (platform::is_cpu_place(src_place) &&
#endif platform::is_gpu_place(dst_place)) {
}
#ifndef PADDLE_ONLY_CPU
template <typename T>
inline void Tensor::CopyFrom(const Tensor& src,
const platform::GPUPlace& dst_place) {
src.check_memory_size<T>();
Resize(src.dims());
auto src_place = src.holder_->place();
auto src_ptr = static_cast<const void*>(src.data<T>());
auto dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
auto size = product(src.dims_) * sizeof(T);
if (platform::is_cpu_place(src_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr, memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size, 0); boost::get<platform::CPUPlace>(src_place), src_ptr, size, 0);
} else if (platform::is_gpu_place(src_place)) { } else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr, memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0); boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
} }
}
#endif #endif
}
template <typename T> template <typename T>
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
......
...@@ -94,14 +94,7 @@ class Tensor { ...@@ -94,14 +94,7 @@ class Tensor {
 * @note CopyFrom supports CPU <-> CPU, CPU <-> GPU, GPU <-> GPU. * @note CopyFrom supports CPU <-> CPU, CPU <-> GPU, GPU <-> GPU.
*/ */
template <typename T> template <typename T>
inline void CopyFrom(const Tensor& src, inline void CopyFrom(const Tensor& src, const platform::Place& dst_place);
const platform::CPUDeviceContext& ctx);
#ifndef PADDLE_ONLY_CPU
template <typename T>
inline void CopyFrom(const Tensor& src,
const platform::CUDADeviceContext& ctx);
#endif
/** /**
* @brief Return the slice of the tensor. * @brief Return the slice of the tensor.
......
...@@ -34,7 +34,7 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place, ...@@ -34,7 +34,7 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
void* dst, void* dst,
platform::GPUPlace src_place, platform::GPUPlace src_place,
const void* src, size_t num, const void* src, size_t num,
cudaStream_t stream = 0) { cudaStream_t stream) {
platform::SetDeviceId(src_place.device); platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
} }
...@@ -44,7 +44,7 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place, ...@@ -44,7 +44,7 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
void* dst, void* dst,
platform::CPUPlace src_place, platform::CPUPlace src_place,
const void* src, size_t num, const void* src, size_t num,
cudaStream_t stream = 0) { cudaStream_t stream) {
platform::SetDeviceId(dst_place.device); platform::SetDeviceId(dst_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
} }
...@@ -54,7 +54,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place, ...@@ -54,7 +54,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
void* dst, void* dst,
platform::GPUPlace src_place, platform::GPUPlace src_place,
const void* src, size_t num, const void* src, size_t num,
cudaStream_t stream = 0) { cudaStream_t stream) {
if (dst_place == src_place) { if (dst_place == src_place) {
platform::SetDeviceId(src_place.device); platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
......
...@@ -51,7 +51,7 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); ...@@ -51,7 +51,7 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num);
*/ */
template <typename DstPlace, typename SrcPlace> template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num, void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
cudaStream_t stream = 0); cudaStream_t stream);
#endif // PADDLE_ONLY_CPU #endif // PADDLE_ONLY_CPU
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册