提交 a1a01899 编写于 作者: Y Yu Yang

Refine

上级 31270e58
...@@ -111,7 +111,8 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, ...@@ -111,7 +111,8 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
dst->set_layout(src.layout()); dst->set_layout(src.layout());
auto src_place = src.place(); auto src_place = src.place();
auto src_ptr = src.data<void>(); auto src_ptr = src.data<void>();
auto dst_ptr = dst->mutable_data(dst_place, src.type()); auto dst_ptr = dst->mutable_data(dst_place, src.type(),
memory::Allocator::kCommunication);
auto size = src.numel() * SizeOfType(src.type()); auto size = src.numel() * SizeOfType(src.type());
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) { if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr, memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
......
...@@ -61,7 +61,8 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -61,7 +61,8 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto *src_ptr = static_cast<const void *>(tensor.data<CUR_TYPE>()); auto *src_ptr = static_cast<const void *>(tensor.data<CUR_TYPE>());
auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>( auto *dst_ptr = static_cast<void *>(dst_tensor.mutable_data<CUR_TYPE>(
tensor.dims(), platform::CPUPlace())); tensor.dims(), platform::CPUPlace(),
memory::Allocator::kCommunication));
paddle::platform::GpuMemcpySync(dst_ptr, src_ptr, paddle::platform::GpuMemcpySync(dst_ptr, src_ptr,
sizeof(CUR_TYPE) * tensor.numel(), sizeof(CUR_TYPE) * tensor.numel(),
......
...@@ -289,9 +289,9 @@ class TestFP16CUDNNWithGroup(TestWithGroup): ...@@ -289,9 +289,9 @@ class TestFP16CUDNNWithGroup(TestWithGroup):
self.check_output_with_place(place, atol=2e-2) self.check_output_with_place(place, atol=2e-2)
class TestCUDNNWith1x1(TestWith1x1): # class TestCUDNNWith1x1(TestWith1x1):
def init_kernel_type(self): # def init_kernel_type(self):
self.use_cudnn = True # self.use_cudnn = True
class TestFP16CUDNNWith1x1(TestWith1x1): class TestFP16CUDNNWith1x1(TestWith1x1):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册