提交 9a05c907 编写于 作者: T typhoonzero

fix StridedNumelCopyWithAxis

上级 01f4bcb5
......@@ -58,6 +58,7 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
int64_t src_after = src_stride_numel[axis];
int64_t dst_after = dst_stride_numel[axis];
+ int64_t copy_size = std::min(src_after, dst_after);
auto place = ctx.GetPlace();
PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(),
......@@ -82,14 +83,14 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
if (platform::is_cpu_place(place)) {
auto& cpu_place = boost::get<platform::CPUPlace>(place);
memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
- src + i * src_after, sizeof(T) * src_after);
+ src + i * src_after, sizeof(T) * copy_size);
} else {
#ifdef PADDLE_WITH_CUDA
auto& gpu_place = boost::get<platform::CUDAPlace>(place);
auto& cuda_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
- src + i * src_after, sizeof(T) * src_after,
+ src + i * src_after, sizeof(T) * copy_size,
cuda_ctx.stream());
#else
PADDLE_THROW("Paddle is not compiled with GPU");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册