提交 303fb789 编写于 作者: Q qijun

refine tensor copy from

上级 5364b394
......@@ -83,7 +83,7 @@ inline void Tensor::ShareDataWith(const Tensor& src) {
template <typename T>
inline void Tensor::CopyFrom(const Tensor& src,
const platform::CPUPlace& dst_place) {
const platform::Place& dst_place) {
src.check_memory_size<T>();
Resize(src.dims());
......@@ -94,41 +94,27 @@ inline void Tensor::CopyFrom(const Tensor& src,
auto size = product(src.dims_) * sizeof(T);
if (platform::is_cpu_place(src_place)) {
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
#ifndef PADDLE_ONLY_CPU
else if (platform::is_gpu_place(src_place)) {
else if (platform::is_gpu_place(src_place) &&
platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
}
#endif
}
#ifndef PADDLE_ONLY_CPU
template <typename T>
inline void Tensor::CopyFrom(const Tensor& src,
const platform::GPUPlace& dst_place) {
src.check_memory_size<T>();
Resize(src.dims());
auto src_place = src.holder_->place();
auto src_ptr = static_cast<const void*>(src.data<T>());
auto dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
auto size = product(src.dims_) * sizeof(T);
if (platform::is_cpu_place(src_place)) {
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size, 0);
} else if (platform::is_gpu_place(src_place)) {
} else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
}
}
#endif
}
template <typename T>
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
......
......@@ -94,14 +94,7 @@ class Tensor {
* @note CopyFrom supports CPU <-> GPU, GPU <-> GPU.
*/
template <typename T>
inline void CopyFrom(const Tensor& src,
const platform::CPUDeviceContext& ctx);
#ifndef PADDLE_ONLY_CPU
template <typename T>
inline void CopyFrom(const Tensor& src,
const platform::CUDADeviceContext& ctx);
#endif
inline void CopyFrom(const Tensor& src, const platform::Place& dst_place);
/**
* @brief Return the slice of the tensor.
......
......@@ -34,7 +34,7 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
void* dst,
platform::GPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream = 0) {
cudaStream_t stream) {
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
}
......@@ -44,7 +44,7 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
void* dst,
platform::CPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream = 0) {
cudaStream_t stream) {
platform::SetDeviceId(dst_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
}
......@@ -54,7 +54,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
void* dst,
platform::GPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream = 0) {
cudaStream_t stream) {
if (dst_place == src_place) {
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
......
......@@ -51,7 +51,7 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num);
*/
template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
cudaStream_t stream = 0);
cudaStream_t stream);
#endif // PADDLE_ONLY_CPU
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册